Efficient Long-context Language Model Training by Core Attention Disaggregation
This work addresses training efficiency for researchers and practitioners who train large language models on long contexts. It is incremental: it optimizes an existing bottleneck rather than introducing a new paradigm.
The paper tackles load imbalance in long-context large language model training. The imbalance arises because core attention compute grows quadratically with sequence length, while the rest of the model grows near-linearly. The paper presents core attention disaggregation (CAD), which decouples core attention and schedules it separately. CAD yields up to 1.35x higher training throughput and near-perfect balance on 512 GPUs with contexts up to 512k tokens.
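The following back-of-the-envelope sketch makes the imbalance concrete. It assumes a hypothetical cost model: non-attention layers cost a constant amount per token, and causal core attention over a document of length L costs L(L+1)/2 query-key pairs. Constants such as hidden size and FLOPs per pair are ignored, and the function names are illustrative rather than taken from DistCA.

```python
# Hypothetical cost model (illustration only): per-token work of the
# non-attention layers (projections, MLP) is treated as constant, while
# causal core attention over a document of length L costs L*(L+1)/2
# query-key pairs.

def linear_cost(doc_lengths):
    """Work of non-attention layers: proportional to total tokens."""
    return sum(doc_lengths)

def core_attention_cost(doc_lengths):
    """Causal attention work: token i of a document attends to i + 1 keys."""
    return sum(L * (L + 1) // 2 for L in doc_lengths)

# Two packed microbatches with the same token budget (8192 tokens each)
# but different document-length mixes.
batch_a = [8192]       # one long document
batch_b = [1024] * 8   # eight short documents

for name, batch in [("A", batch_a), ("B", batch_b)]:
    print(name, "linear:", linear_cost(batch), "attention:", core_attention_cost(batch))

ratio = core_attention_cost(batch_a) / core_attention_cost(batch_b)
print(f"core attention imbalance: {ratio:.2f}x")
```

Both microbatches carry 8,192 tokens, so the linear layers do identical work. The core attention work of A is nonetheless roughly 8x that of B, which matches the ratio of document lengths. Under this model, a data- or pipeline-parallel rank that receives A becomes the straggler.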
We present core attention disaggregation (CAD), a technique that improves long-context large language model training by decoupling the core attention computation, softmax(QK^T)V, from the rest of the model and executing it on a separate pool of devices. In existing systems, core attention is colocated with other layers. At long context lengths, its compute grows quadratically while that of the other components grows near-linearly, which causes load imbalance and stragglers across data- and pipeline-parallel groups. CAD is enabled by two observations. First, core attention is stateless: it has no trainable parameters and only minimal transient data, so balancing reduces to scheduling compute-bound tasks. Second, it is composable: modern attention kernels retain high efficiency when processing fused batches of token-level shards with arbitrary lengths. CAD partitions core attention into token-level tasks and dispatches them to dedicated attention servers, which dynamically rebatch tasks to equalize compute without sacrificing kernel efficiency. We implement CAD in a system called DistCA. DistCA uses a ping-pong execution scheme to fully overlap communication with computation, and in-place execution on attention servers to reduce memory use. On 512 H200 GPUs and context lengths up to 512k tokens, DistCA improves end-to-end training throughput by up to 1.35x, eliminates data- and pipeline-parallel stragglers, and achieves near-perfect compute and memory balance.
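The statelessness and composability observations can be checked numerically. The NumPy sketch below uses a hypothetical `causal_attention` helper. It computes causal softmax(QK^T)V for a whole document, then recomputes it shard by shard, with each arbitrary-length token shard given only its own queries and the key/value prefix it needs. The sketch illustrates why token-level sharding is exact. It does not model kernel efficiency or how DistCA actually lays out shards.

```python
import numpy as np

def causal_attention(q, k, v, q_offset=0):
    """Hypothetical helper: softmax(Q K^T / sqrt(d)) V with a causal mask.

    q holds queries for absolute positions q_offset .. q_offset + len(q) - 1;
    k and v hold keys/values for positions 0 .. len(k) - 1.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    q_pos = np.arange(q.shape[0])[:, None] + q_offset
    k_pos = np.arange(k.shape[0])[None, :]
    scores = np.where(k_pos <= q_pos, scores, -np.inf)  # mask future keys
    scores -= scores.max(axis=-1, keepdims=True)        # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
seq_len, d = 16, 8
q = rng.standard_normal((seq_len, d))
k = rng.standard_normal((seq_len, d))
v = rng.standard_normal((seq_len, d))

full = causal_attention(q, k, v)

# Split the document's queries into shards of arbitrary length. Each shard
# needs only its queries plus the key/value prefix up to its last token:
# no parameters and no state are carried between shards.
boundaries = [0, 3, 10, 16]
shards = [
    causal_attention(q[s:e], k[:e], v[:e], q_offset=s)
    for s, e in zip(boundaries[:-1], boundaries[1:])
]
print(np.allclose(np.concatenate(shards), full))
```

Because each shard's output depends only on its inputs, shards from different documents can in principle be computed on any device and in any grouping.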
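A simple scheduler can approximate the dispatch-and-rebatch step. The sketch below cuts documents into fixed-size query shards, estimates each shard's causal-attention cost, and assigns shards to attention servers with greedy longest-processing-time-first (LPT) list scheduling. LPT is a stand-in chosen for illustration, not the paper's scheduling algorithm. `Shard`, `make_shards`, `assign`, the shard size, and the document lengths are all hypothetical, and the sketch ignores communication and memory.

```python
import heapq
from dataclasses import dataclass

@dataclass
class Shard:
    """Hypothetical token-level task: a contiguous query range of one document."""
    doc_id: int
    start: int  # first query position (inclusive)
    end: int    # last query position (exclusive)

    @property
    def cost(self) -> int:
        # Causal mask: query i attends to keys 0..i, i.e. i + 1 pairs.
        # Sum of (i + 1) for i in [start, end).
        return (self.end * (self.end + 1) - self.start * (self.start + 1)) // 2

def make_shards(doc_lengths, shard_size):
    """Cut every document into contiguous query shards of <= shard_size tokens."""
    shards = []
    for doc_id, length in enumerate(doc_lengths):
        for start in range(0, length, shard_size):
            shards.append(Shard(doc_id, start, min(start + shard_size, length)))
    return shards

def assign(shards, num_servers):
    """Greedy LPT: give each shard, largest first, to the least-loaded server."""
    heap = [(0, server) for server in range(num_servers)]
    batches = [[] for _ in range(num_servers)]
    for shard in sorted(shards, key=lambda s: s.cost, reverse=True):
        load, server = heapq.heappop(heap)
        batches[server].append(shard)
        heapq.heappush(heap, (load + shard.cost, server))
    return batches

docs = [32768, 4096, 4096, 2048, 1024, 1024]  # hypothetical packed documents
batches = assign(make_shards(docs, shard_size=2048), num_servers=4)
loads = [sum(s.cost for s in batch) for batch in batches]
print(loads, "max/mean:", max(loads) / (sum(loads) / len(loads)))
```

Under the causal mask, later shards of a document cost more than earlier shards of the same size, so this scheduler balances on estimated compute rather than token count. Calling `assign` with `shard_size=max(docs)` keeps each document whole. The 32k-token document then lands on a single server whose load alone is nearly 4x the mean, which is the kind of straggler that token-level sharding is meant to remove.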