Controlling changes to attention logits
This work addresses stability issues in transformer training, particularly for specialized architectures such as Multi Latent Attention, though the contribution is incremental because it builds on existing normalization techniques.
The paper tackles the problem of stabilizing transformer training by controlling changes to attention logits. It shows that assigning parameter-dependent learning rates to the query and key weights allows a higher base learning rate, outperforms other methods with Multi Latent Attention, and is competitive with QK norm when using Multi-head Attention.
Stability of neural network weights is critical when training transformer models. The query and key weights are particularly problematic, as they tend to grow large without any intervention. Applying normalization to queries and keys, known as 'QK norm', fixes stability issues in practice, but is not always applicable. For example, QK norm is not compatible with Multi Latent Attention (MLA) because QK norm requires full materialization of queries and keys during inference, which is not done in MLA. In this paper we suggest that controlling the changes to logits is important for stability. We show that these changes are controllable by assigning parameter-dependent learning rates to the query and key weights. We find that our cheap intervention allows us to increase the base learning rate of the network, outperform other methods in the MLA setting, and achieve performance competitive with QK norm when using Multi-head Attention.
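To illustrate what a parameter-dependent learning rate for the query and key weights could look like, the following is a minimal sketch and not necessarily the rule used in the paper. It assumes a single attention logit of the form (x W_Q) · (y W_K), for which a step ΔW_Q changes the logit by x ΔW_Q W_K^T y^T, a quantity that grows with the norm of W_K. The helper names (`qk_learning_rates`, `logit_change_after_q_step`), the inverse-Frobenius-norm scaling, and the plain gradient-descent step are all assumptions made for illustration.

```python
import math


def frobenius_norm(w):
    """Frobenius norm of a matrix stored as a list of rows."""
    return math.sqrt(sum(v * v for row in w for v in row))


def qk_learning_rates(base_lr, w_q, w_k, eps=1e-8):
    """Hypothetical rule: scale each matrix's learning rate by the
    inverse norm of its partner, so that a step on W_Q produces a
    logit change that does not grow with the size of W_K (and vice versa)."""
    lr_q = base_lr / (frobenius_norm(w_k) + eps)
    lr_k = base_lr / (frobenius_norm(w_q) + eps)
    return lr_q, lr_k


def matvec(x, w):
    """Row vector x (length d) times matrix w (d x h)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def logit(x, y, w_q, w_k):
    """Unscaled attention logit between a query token x and a key token y."""
    q = matvec(x, w_q)
    k = matvec(y, w_k)
    return sum(a * b for a, b in zip(q, k))


def logit_change_after_q_step(x, y, w_q, w_k, grad_q, base_lr, scaled=True):
    """Apply one plain gradient-descent step to W_Q and return the change
    in the logit. With scaled=False the base learning rate is used as-is."""
    lr_q = qk_learning_rates(base_lr, w_q, w_k)[0] if scaled else base_lr
    new_w_q = [
        [w - lr_q * g for w, g in zip(w_row, g_row)]
        for w_row, g_row in zip(w_q, grad_q)
    ]
    return logit(x, y, new_w_q, w_k) - logit(x, y, w_q, w_k)
```

Under these assumptions, multiplying W_K by a constant leaves the logit change from a W_Q step essentially unchanged when the scaled learning rate is used, whereas with a fixed learning rate the change grows in proportion to W_K. This is one way to read the claim that changes to logits are controllable through the learning rates of the query and key weights, without normalizing queries and keys at inference time.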