On the Runway Cascade of Transformers for Language Modeling
This paper addresses failure modes in causal transformers for language modeling, offering a parameter-free enhancement that improves performance on specific tasks. The contribution is incremental in that it builds on the existing transformer architecture.
The paper tackles misaligned information propagation in causal transformers, which causes redundant and irrelevant information to cascade into token representations. It proposes runway-aware rewiring, which incorporates runway context into attention, and reports steady improvements in language modeling along with stronger information retrieval and extrapolation abilities.
In decoder-only (causal) transformers, the computation graph created by causal masking routes information through both direct-path attention and indirect paths formed by intermediate tokens. We denote these indirect paths between token pairs as their runways. We argue that certain failure modes of causal transformers, as observed by a growing body of recent work, are likely exacerbated by a misalignment between these two information propagation modes. We formalize runway cascade as a phenomenon whereby this misalignment results in redundancies and irrelevant information cascading to token representations despite adequately learned attention patterns. As a solution, we propose runway-aware rewiring as a more explicit way of incorporating runway context directly into each token's direct-path attention. This mechanism rewires the attention pattern for each token based on a summary of its runway landscape, enabling awareness of accumulating representational influences and allowing for more balanced information propagation. Our proposed methodology introduces no additional parameters and can be seamlessly integrated into the standard attention mechanism. Empirically, our rewired transformer yields steady improvements in general language modeling as well as noticeably stronger information retrieval and extrapolation abilities compared to standard transformers.
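To make the distinction between direct-path attention and runways concrete, the following is a minimal NumPy sketch. It is an illustration only and is not the paper's runway-aware rewiring, which the abstract does not specify. It assumes a single causal attention matrix is reused at every hop as a stand-in for propagation across layers, and the function names `causal_attention` and `runway_influence` are hypothetical. For each token pair, it sums the weight carried by every indirect path that passes through at least one strictly intermediate token.

```python
import numpy as np


def causal_attention(scores):
    """Row-wise softmax over a causally masked (lower-triangular) score matrix."""
    n = scores.shape[0]
    mask = np.tril(np.ones((n, n), dtype=bool))
    masked = np.where(mask, scores, -np.inf)
    # Subtract the row maximum for numerical stability; the diagonal is always
    # unmasked, so every row has a finite maximum.
    masked = masked - masked.max(axis=1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=1, keepdims=True)


def runway_influence(attn):
    """Total weight of indirect paths between token pairs (illustrative assumption).

    Entry [i, j] sums the products of attention weights along every path
    j -> k1 -> ... -> i that visits at least one token strictly between j and i.
    """
    n = attn.shape[0]
    # Drop self-attention so that every hop moves to a strictly earlier token.
    strict = np.tril(attn, k=-1)
    hop = strict @ strict  # paths with exactly one intermediate token
    total = np.zeros_like(attn)
    for _ in range(2, n):  # path lengths 2 .. n-1; longer paths cannot exist
        total += hop
        hop = hop @ strict
    return total


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    attn = causal_attention(rng.normal(size=(5, 5)))
    direct = np.tril(attn, k=-1)      # direct-path attention to earlier tokens
    runway = runway_influence(attn)   # influence arriving via intermediate tokens
    print(np.round(direct, 3))
    print(np.round(runway, 3))
```

In this sketch, adjacent tokens have no intermediate token between them, so their runway entry is zero, while distant pairs accumulate weight from many indirect paths. Comparing `direct` with `runway` gives a rough picture of how the two propagation modes can diverge even when the attention pattern itself is reasonable.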