Training Dynamics of Softmax Self-Attention: Fast Global Convergence via Preconditioning
This work provides global convergence guarantees for training softmax self-attention, a core component of transformers, despite the nonconvexity of the underlying optimization problem.
The authors study softmax self-attention layers trained to perform linear regression, analyze the dynamics of gradient descent, and show that a novel first-order optimization algorithm converges to the globally optimal parameters at a geometric rate.
We study the training dynamics of gradient descent in a softmax self-attention layer trained to perform linear regression and show that a simple first-order optimization algorithm can converge to the globally optimal self-attention parameters at a geometric rate. Our analysis proceeds in two steps. First, we show that in the infinite-data limit the regression problem solved by the self-attention layer is equivalent to a nonconvex matrix factorization problem. Second, we exploit this connection to design a novel "structure-aware" variant of gradient descent which efficiently optimizes the original finite-data regression objective. Our optimization algorithm features several innovations over standard gradient descent, including a preconditioner and regularizer which help avoid spurious stationary points, and a data-dependent spectral initialization of parameters which lie near the manifold of global minima with high probability.
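The abstract does not spell out the algorithm. The following is therefore only a minimal sketch of the general idea it describes: preconditioned gradient descent with a spectral initialization, applied to a generic rank-r matrix factorization problem f(L, R) = 0.5 * ||L R^T - M||_F^2. It is not the authors' method. The preconditioner here is the ScaledGD-style inverse-Gram scaling from the matrix-factorization literature, and the regularizer mentioned in the abstract is omitted. The helper names `spectral_init` and `scaled_gd` are hypothetical. The noisy matrix `M_hat` is also hypothetical; it stands in for a finite-data estimate of the target.

```python
import numpy as np


def spectral_init(M_hat, r):
    """Rank-r spectral initialization from an estimate M_hat of the target.

    Returns balanced factors L0 = U_r S_r^{1/2} and R0 = V_r S_r^{1/2},
    built from the top-r singular triplets of M_hat.
    """
    U, s, Vt = np.linalg.svd(M_hat, full_matrices=False)
    root = np.sqrt(s[:r])
    return U[:, :r] * root, Vt[:r].T * root


def scaled_gd(M, L, R, eta=0.5, iters=100):
    """Preconditioned (ScaledGD-style) gradient descent on
    f(L, R) = 0.5 * ||L R^T - M||_F^2.

    Each factor's gradient is right-multiplied by the inverse Gram matrix
    of the other factor, which makes the step size insensitive to the
    conditioning of M. Returns the final factors and the relative error
    ||L R^T - M||_F / ||M||_F recorded before each step and after the last.
    """
    errors = []
    norm_M = np.linalg.norm(M)
    for _ in range(iters):
        residual = L @ R.T - M              # gradient of f w.r.t. the product L R^T
        errors.append(np.linalg.norm(residual) / norm_M)
        grad_L = residual @ R               # df/dL
        grad_R = residual.T @ L             # df/dR
        # The preconditioners (R^T R)^{-1} and (L^T L)^{-1} are only r x r.
        L_new = L - eta * grad_L @ np.linalg.inv(R.T @ R)
        R_new = R - eta * grad_R @ np.linalg.inv(L.T @ L)
        L, R = L_new, R_new
    errors.append(np.linalg.norm(L @ R.T - M) / norm_M)
    return L, R, errors


# Usage: a random rank-3 target, initialized from a slightly noisy copy of it.
rng = np.random.default_rng(0)
d1, d2, r = 30, 20, 3
M = rng.standard_normal((d1, r)) @ rng.standard_normal((r, d2))
M_hat = M + 1e-2 * rng.standard_normal((d1, d2))
L0, R0 = spectral_init(M_hat, r)
L, R, errors = scaled_gd(M, L0, R0)
print(errors[0], errors[10], errors[-1])
```

Near the global minimum, the relative error in this toy setting should shrink by a roughly constant factor per iteration (a geometric rate) until it reaches floating-point precision. The spectral initialization is what places the starting point close enough for that local behavior to take over.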