JoMA: Demystifying Multilayer Transformers via JOint Dynamics of MLP and Attention
This work provides a theoretical foundation for interpreting Transformer training. The contribution is incremental, but it addresses a key bottleneck in AI interpretability.
The authors tackle the problem of understanding training dynamics in multilayer Transformers by proposing JoMA, a mathematical framework that integrates out self-attention so that only the MLP layers need to be analyzed. JoMA predicts how attention sparsity evolves during training, and the authors verify these predictions with experiments on real-world datasets and pre-trained models.
We propose Joint MLP/Attention (JoMA) dynamics, a novel mathematical framework for understanding the training procedure of multilayer Transformer architectures. This is achieved by integrating out the self-attention layer in Transformers, producing modified dynamics of the MLP layers only. JoMA removes unrealistic assumptions made in previous analyses (e.g., the lack of residual connections) and predicts that, in the presence of nonlinear activations, attention first becomes sparse (to learn salient tokens) and then dense (to learn less salient tokens). In the linear case, it is consistent with existing works showing that attention becomes sparse over time. We leverage JoMA to qualitatively explain how tokens are combined to form hierarchies in multilayer Transformers when the input tokens are generated by a latent hierarchical generative model. Experiments on models trained on real-world datasets (Wikitext2/Wikitext103) and on various pre-trained models (OPT, Pythia) verify our theoretical findings. Code is available at https://github.com/facebookresearch/luckmatters/tree/yuandong3.
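The abstract describes attention as becoming "sparse" or "dense". One common way to quantify this is the Shannon entropy of a query's attention distribution: low entropy means the mass sits on a few salient tokens, and high entropy means it is spread across many. This is an illustrative assumption, not necessarily the metric used in the paper. The minimal Python sketch below uses hypothetical logits and the helper names `softmax` and `attention_entropy`, which are introduced only for illustration:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of attention logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def attention_entropy(weights):
    """Shannon entropy (in nats) of one attention distribution.

    Illustrative assumption: low entropy = sparse attention (mass on few
    tokens), high entropy = dense attention (mass spread over many tokens).
    """
    return -sum(w * math.log(w) for w in weights if w > 0)

# Hypothetical query-key logits over 4 tokens (not taken from the paper).
dense = softmax([0.0, 0.0, 0.0, 0.0])   # uniform: entropy equals log(4)
sparse = softmax([8.0, 0.0, 0.0, 0.0])  # nearly all mass on one salient token

print(round(attention_entropy(dense), 4))  # 1.3863
print(attention_entropy(sparse) < attention_entropy(dense))  # True
```

Tracking this quantity per head across training checkpoints would be one way to look for a sparse-then-dense trend of the kind JoMA predicts, though the paper's own measurement procedure may differ.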