Homogenized Transformers
This work addresses representation collapse in transformers, an issue of practical concern for AI practitioners. Its contribution is incremental: it builds on existing theoretical frameworks and does not introduce new methods.
The paper tackles representation collapse in deep multi-head self-attention models by analyzing a random model with resampled weights. It derives a homogenized limit that reveals quantitative trade-offs between dimension, context length, and temperature, and it identifies regimes in which clustering can be mitigated.
We study a random model of deep multi-head self-attention in which the weights are resampled independently across layers and heads, as at initialization of training. Viewing depth as a time variable, the residual stream defines a discrete-time interacting particle system on the unit sphere. We prove that, under suitable joint scalings of the depth, the residual step size, and the number of heads, this dynamics admits a nontrivial homogenized limit. Depending on the scaling, the limit is either deterministic or stochastic with common noise; in the mean-field regime, the latter leads to a stochastic nonlinear Fokker--Planck equation for the conditional law of a representative token. In the Gaussian setting, the limiting drift vanishes, making the homogenized dynamics explicit enough to study representation collapse. This yields quantitative trade-offs between dimension, context length, and temperature, and identifies regimes in which clustering can be mitigated.
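To make the setup concrete, the following is a minimal sketch of the kind of discrete-time particle system the abstract describes: tokens on the unit sphere, updated layer by layer by multi-head softmax attention with freshly resampled Gaussian weights. Everything specific here is an assumption made for illustration and is not taken from the paper: the N(0, 1/d) weight normalization, the averaging over heads, the renormalization after each residual step, and the names `simulate`, `layer`, `step`, and `beta` (the inverse temperature). The paper's actual scalings of depth, step size, and number of heads may differ.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(v):
    """Project a nonzero vector onto the unit sphere."""
    norm = math.sqrt(dot(v, v))
    return [c / norm for c in v]


def matvec(m, v):
    return [dot(row, v) for row in m]


def gaussian_matrix(d, rng):
    """d x d matrix with i.i.d. N(0, 1/d) entries (assumed normalization)."""
    scale = 1.0 / math.sqrt(d)
    return [[rng.gauss(0.0, scale) for _ in range(d)] for _ in range(d)]


def layer(tokens, heads, step, beta, rng):
    """One residual attention layer with weights resampled for every head."""
    n, d = len(tokens), len(tokens[0])
    update = [[0.0] * d for _ in range(n)]
    for _ in range(heads):
        q, k, v = (gaussian_matrix(d, rng) for _ in range(3))
        queries = [matvec(q, x) for x in tokens]
        keys = [matvec(k, x) for x in tokens]
        values = [matvec(v, x) for x in tokens]
        for i in range(n):
            scores = [beta * dot(queries[i], keys[j]) for j in range(n)]
            top = max(scores)  # subtract the max for a stable softmax
            weights = [math.exp(s - top) for s in scores]
            total = sum(weights)
            for j in range(n):
                w = weights[j] / total
                for c in range(d):
                    # Averaging over heads is an assumed scaling choice.
                    update[i][c] += w * values[j][c] / heads
    # Residual step followed by projection back onto the sphere.
    return [
        normalize([x[c] + step * u[c] for c in range(d)])
        for x, u in zip(tokens, update)
    ]


def mean_pairwise_similarity(tokens):
    """Average inner product over distinct pairs; near 1 suggests clustering."""
    n = len(tokens)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    return sum(dot(tokens[i], tokens[j]) for i, j in pairs) / len(pairs)


def simulate(n_tokens, dim, depth, heads, step, beta, seed=0):
    """Run `depth` layers from uniformly random initial tokens on the sphere."""
    rng = random.Random(seed)
    tokens = [
        normalize([rng.gauss(0.0, 1.0) for _ in range(dim)])
        for _ in range(n_tokens)
    ]
    for _ in range(depth):
        tokens = layer(tokens, heads, step, beta, rng)
    return tokens
```

Running `simulate` at several values of `dim`, `n_tokens`, and `beta` and comparing `mean_pairwise_similarity` across runs gives a crude empirical diagnostic of clustering. It is a toy experiment under the assumptions above, not a reproduction of the paper's results.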