Critical attention scaling in long-context transformers
This work addresses a fundamental bottleneck in scaling transformers to longer contexts, which is crucial for improving large language models. Its contribution is incremental, however, since it builds on existing attention scaling techniques.
The paper tackles attention collapse in long-context transformers, where attention scores become uniform as the context length $n$ grows. It shows that a critical scaling factor of order $\log n$ prevents this collapse without reducing attention to the identity, which justifies the attention scaling used in models such as YaRN and Qwen.
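A back-of-the-envelope calculation suggests why $\log n$ is the natural scale. It assumes a hypothetical toy setting, not the paper's model: one query attends to $n$ keys, a single key scores $\Delta > 0$ above the other $n-1$ (which tie), and every score is multiplied by a factor $\beta_n$ before the softmax. The weight on the top key is then

$$
p_{\max} = \frac{e^{\beta_n \Delta}}{e^{\beta_n \Delta} + (n-1)} = \frac{1}{1 + (n-1)\,e^{-\beta_n \Delta}} .
$$

With $\beta_n = c \log n$ we have $(n-1)\,e^{-\beta_n \Delta} = (n-1)\,n^{-c\Delta}$, which grows like $n^{1-c\Delta}$, so as $n \to \infty$

$$
p_{\max} \longrightarrow
\begin{cases}
0, & c\Delta < 1 \quad \text{(weight spreads uniformly over the other keys)},\\
1/2, & c\Delta = 1,\\
1, & c\Delta > 1 \quad \text{(hard selection of the top key)}.
\end{cases}
$$

A constant factor $\beta$ falls in the first regime, since $p_{\max} \approx e^{\beta\Delta}/n \to 0$. For a fixed gap the transition therefore sits at $\beta_n = \log n / \Delta$, i.e. $\beta_n \asymp \log n$. If each token's largest score is with itself, as in self-attention, the hard-selection regime is exactly attention reduced to the identity. This heuristic only locates the scale of the threshold; it is not the paper's rigorous argument.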
As large language models scale to longer contexts, attention layers suffer from a fundamental pathology: attention scores collapse toward uniformity as the context length $n$ increases, causing tokens to cluster excessively, a phenomenon known as rank collapse. While $\textit{attention scaling}$ effectively addresses this deficiency by rescaling attention scores with a polylogarithmic factor $\beta_n$, a theoretical justification for this approach remains lacking. We analyze a simplified yet tractable model that magnifies the effect of attention scaling. In this model, attention exhibits a phase transition governed by the scaling factor $\beta_n$: insufficient scaling collapses all tokens to a single direction, while excessive scaling reduces attention to the identity, eliminating meaningful interactions between tokens. Our main result identifies the critical scaling $\beta_n \asymp \log n$ and provides a rigorous justification for attention scaling in YaRN and Qwen, clarifying why logarithmic scaling maintains sparse, content-adaptive attention at large context lengths.
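The phase transition can be illustrated numerically with a minimal sketch. It assumes a hypothetical single-query toy setting rather than the paper's model: one query over $n$ keys, where key 0 scores `gap` above the remaining keys, which tie at 0, and all scores are multiplied by $\beta_n$ before the softmax. The helper names `scaled_softmax`, `toy_attention`, and `normalized_entropy` are illustrative, not from the paper. The script compares a constant factor with $\beta_n = c \log n$ for $c \in \{0.5, 1, 2\}$, reporting the weight on the top key and the attention entropy divided by $\log n$ (1 for uniform attention, 0 for one-hot).

```python
import math


def scaled_softmax(scores, beta):
    """Softmax of beta * scores, stabilised by subtracting the maximum."""
    scaled = [beta * s for s in scores]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = math.fsum(exps)  # accurate summation over many small terms
    return [e / total for e in exps]


def toy_attention(n, beta, gap=1.0):
    """Attention weights of one query over n keys in a hypothetical toy model:
    key 0 scores `gap` higher than the other n - 1 keys, which all score 0."""
    scores = [gap] + [0.0] * (n - 1)
    return scaled_softmax(scores, beta)


def normalized_entropy(weights):
    """Shannon entropy divided by log(n): 1.0 means uniform, 0.0 means one-hot."""
    h = -math.fsum(w * math.log(w) for w in weights if w > 0)
    return h / math.log(len(weights))


if __name__ == "__main__":
    # Compare a constant scaling factor with beta_n = c * log(n).
    schedules = {
        "beta = 1": lambda n: 1.0,
        "beta = 0.5 log n": lambda n: 0.5 * math.log(n),
        "beta = log n": lambda n: math.log(n),
        "beta = 2 log n": lambda n: 2.0 * math.log(n),
    }
    for name, beta_of in schedules.items():
        for n in (10**2, 10**3, 10**4, 10**5):
            w = toy_attention(n, beta_of(n))
            print(f"{name:18s} n={n:>6d}  top weight={w[0]:.4f}  "
                  f"norm. entropy={normalized_entropy(w):.4f}")
```

With a unit gap, the constant and $0.5 \log n$ schedules drive the top weight toward 0 and the normalized entropy toward 1 as $n$ grows (collapse toward uniformity), $\log n$ holds the top weight near 1/2, and $2 \log n$ pushes it toward 1 (hard selection). A numerically stable softmax is used so that large $\beta_n$ does not overflow `math.exp`.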