Linear attention is (maybe) all you need (to understand transformer optimization)
This work targets researchers trying to understand why Transformers are hard to optimize. It offers a simplified model that may help demystify training difficulties, though the contribution is incremental, since it builds on prior studies of linearized Transformers.
The paper studies the difficulty of Transformer training through a simple linearized shallow Transformer trained on regression tasks. It finds that this model reproduces key aspects of Transformer training dynamics, which suggests it could be a valuable abstraction for understanding optimization.
Transformer training is notoriously difficult, requiring careful optimizer design and various heuristics. We make progress towards understanding the subtleties of training Transformers by carefully studying a simple yet canonical linearized shallow Transformer model. Specifically, we train linear Transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023) and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that our proposed linearized models reproduce several prominent aspects of Transformer training dynamics. These results suggest that a simple linearized Transformer model could be a valuable, realistic abstraction for understanding Transformer optimization.
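To make "linear Transformer on regression tasks" concrete, here is a minimal NumPy sketch of one softmax-free (linear) self-attention layer applied to an in-context linear regression prompt. It assumes the setup popularized by von Oswald et al.: each context token stacks an input x_i with its label y_i, the query token carries a zero label, and the prediction is read (with a sign flip) from the bottom-right entry of the layer output. The function names `linear_self_attention`, `build_prompt`, `predict` and the parameters `P`, `Q`, `eta` are hypothetical and illustrative; this is not the authors' code. With the hand-picked `P` and `Q` below, the single layer reproduces the prediction of one gradient-descent step from w = 0 on the least-squares loss.

```python
import numpy as np

def linear_self_attention(Z, P, Q):
    """One linear (softmax-free) self-attention layer with a residual connection.

    Z has shape (d+1, n+1): columns 0..n-1 are context tokens [x_i; y_i],
    the last column is the query token [x_q; 0]. The mask M removes the
    query column from the key/value sum so the query does not attend to itself.
    """
    n = Z.shape[1] - 1
    M = np.eye(n + 1)
    M[n, n] = 0.0
    return Z + (P @ Z @ M @ Z.T @ Q @ Z) / n

def build_prompt(X, y, x_query):
    """Stack n in-context examples (rows of X, labels y) and the query into a token matrix."""
    tokens = np.vstack([X.T, y[None, :]])               # shape (d+1, n)
    query = np.concatenate([x_query, [0.0]])[:, None]   # shape (d+1, 1), label slot is 0
    return np.hstack([tokens, query])                   # shape (d+1, n+1)

def predict(Z, P, Q):
    # Assumed readout convention: negated bottom-right entry of the layer output.
    return -linear_self_attention(Z, P, Q)[-1, -1]

# Hypothetical random regression task: y = X @ w_star, no noise.
rng = np.random.default_rng(0)
d, n, eta = 5, 20, 0.1
X = rng.normal(size=(n, d))
w_star = rng.normal(size=d)
y = X @ w_star
x_q = rng.normal(size=d)

# Hand-picked weights: P copies the label row, Q scales input inner products by -eta.
P = np.zeros((d + 1, d + 1))
P[d, d] = 1.0
Q = np.zeros((d + 1, d + 1))
Q[:d, :d] = -eta * np.eye(d)

Z = build_prompt(X, y, x_q)
y_hat_lsa = predict(Z, P, Q)

# One gradient-descent step from w = 0 on L(w) = ||X w - y||^2 / (2n).
w_gd = (eta / n) * X.T @ y
y_hat_gd = w_gd @ x_q

print(np.isclose(y_hat_lsa, y_hat_gd))  # True
```

In the trained setting studied in this line of work, `P` and `Q` are learned (and several layers are stacked) rather than fixed by hand; the fixed choice here only illustrates what a single layer can compute.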