LG · AI · DC · Oct 5, 2023

DISTFLASHATTN: Distributed Memory-efficient Attention for Long-context LLMs Training

CMU
arXiv:2310.03294v2 · 0.4838 citations · h-index: 40 · Has Code
AI Analysis · 55

This work addresses the computational bottleneck of scaling LLM training to longer sequences, which is crucial for applications that require extensive context. It is incremental, however, because it builds on FlashAttention and adds distributed optimizations.

The paper tackles the problem of training long-context large language models (LLMs) by introducing DISTFLASHATTN, a distributed memory-efficient attention mechanism that achieves up to 8x longer sequences and speedups of 4.45-5.64x compared to existing methods like Ring Self-Attention.

FlashAttention (Dao, 2023) effectively reduces the quadratic peak memory usage to linear when training transformer-based large language models (LLMs) on a single GPU. In this paper, we introduce DISTFLASHATTN, a distributed memory-efficient attention mechanism optimized for long-context LLM training. We propose three key techniques: token-level workload balancing, overlapping key-value communication, and a rematerialization-aware gradient checkpointing algorithm. We evaluate DISTFLASHATTN on Llama-7B and variants with sequence lengths from 32K to 512K. Compared to Ring Self-Attention, DISTFLASHATTN achieves 8x longer sequences and a 4.45 - 5.64x speedup; compared to Megatron-LM with FlashAttention, it achieves 2 - 8x longer sequences and a 1.24 - 2.01x speedup. It achieves 1.67x and 1.26 - 1.88x speedups compared to the recent Ring Attention and DeepSpeed-Ulysses, respectively. Code is available at https://github.com/RulinShao/LightSeq.
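
To illustrate why token-level workload balancing matters under causal attention, the following sketch counts work units when a sequence is split evenly across workers. It is a back-of-the-envelope model, not the scheduling algorithm from the paper. The assumptions are mine: worker p (0-indexed) holds query chunk p and must attend to key-value chunks 0..p, every chunk pair costs one unit, and the "balanced" figure is an idealized lower bound in which work can be moved freely between workers. All function names are hypothetical.

```python
from math import ceil


def causal_chunk_work(num_workers):
    """Work units per worker: worker p attends to key-value chunks 0..p."""
    return [p + 1 for p in range(num_workers)]


def idle_fraction_unbalanced(num_workers):
    """Idle share when every worker waits for the most loaded one."""
    work = causal_chunk_work(num_workers)
    steps = max(work)  # the last worker needs num_workers steps
    return 1 - sum(work) / (steps * num_workers)


def idle_fraction_balanced_bound(num_workers):
    """Idle share under ideal redistribution (assumed lower bound)."""
    work = causal_chunk_work(num_workers)
    steps = ceil(sum(work) / num_workers)  # total work spread evenly
    return 1 - sum(work) / (steps * num_workers)


if __name__ == "__main__":
    for p in (2, 4, 8, 16):
        print(p,
              round(idle_fraction_unbalanced(p), 4),
              round(idle_fraction_balanced_bound(p), 4))
```

Under these assumptions the unbalanced idle share is (P - 1) / (2P), which approaches one half as the worker count P grows. Letting lightly loaded workers take over part of the heavy workers' computation reduces the number of time steps to roughly half.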

Code Implementations: 1 repo
