Transformers Efficiently Perform In-Context Logistic Regression via Normalized Gradient Descent
This paper advances the theoretical understanding of in-context learning by giving a mechanistic explanation of how softmax transformers can implement normalized gradient descent for linear classification.
This work shows that transformers with softmax attention can perform in-context logistic regression by implicitly executing normalized gradient descent, with each layer corresponding to one step. The authors construct such transformers and provide training convergence and out-of-distribution generalization guarantees.
Transformers have demonstrated remarkable in-context learning (ICL) capabilities. This strong ICL performance is commonly believed to arise from their ability to implicitly execute certain algorithms on the context, thereby enhancing prediction and generation. In this work, we investigate how transformers with softmax attention perform in-context learning on linear classification data. We first construct a class of multi-layer transformers that can perform in-context logistic regression, with each layer exactly performing one step of normalized gradient descent on an in-context loss. Then, we show that our constructed transformer can be obtained by (i) training a single self-attention layer supervised by one-step gradient descent, and (ii) recurrently applying the trained layer to obtain a looped model. We provide training convergence guarantees for the self-attention layer and out-of-distribution generalization guarantees for the looped model. Our results advance the theoretical understanding of the ICL mechanism by showcasing how softmax transformers can effectively act as in-context learners.
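To make the algorithm named in the abstract concrete, here is a minimal sketch of normalized gradient descent on a logistic loss over in-context examples. It is plain numerical optimization, not the attention construction from the paper. The abstract does not spell out the exact in-context loss or normalization, so the mean logistic loss, the Euclidean gradient norm, the zero initialization, and all names (`ngd_step`, `looped_ngd`, `predict`) are assumptions made for illustration. Applying the same `ngd_step` repeatedly is meant only as a loose analogue of looping one trained layer.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logistic_grad(w, xs, ys):
    """Gradient of the mean logistic loss (1/n) * sum_i log(1 + exp(-y_i * <w, x_i>))."""
    n = len(xs)
    grad = [0.0] * len(w)
    for x, y in zip(xs, ys):
        margin = y * sum(wj * xj for wj, xj in zip(w, x))
        # d/dw log(1 + exp(-margin)) = -y * sigmoid(-margin) * x
        coeff = -y * sigmoid(-margin) / n
        for j, xj in enumerate(x):
            grad[j] += coeff * xj
    return grad


def ngd_step(w, xs, ys, lr):
    """One normalized gradient descent step: move a fixed distance lr along -grad."""
    g = logistic_grad(w, xs, ys)
    norm = math.sqrt(sum(gj * gj for gj in g))
    if norm == 0.0:
        return list(w)  # stationary point: nothing to normalize
    return [wj - lr * gj / norm for wj, gj in zip(w, g)]


def looped_ngd(xs, ys, num_steps, lr=1.0):
    """Apply the same step num_steps times, starting from w = 0 (assumed initialization)."""
    w = [0.0] * len(xs[0])
    for _ in range(num_steps):
        w = ngd_step(w, xs, ys, lr)
    return w


def predict(w, x):
    """Linear classifier with labels in {-1, +1}."""
    return 1 if sum(wj * xj for wj, xj in zip(w, x)) >= 0 else -1


# Hypothetical toy context: four labeled examples separable by the direction (1, 1).
context_xs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
context_ys = [1, 1, -1, -1]

w = looped_ngd(context_xs, context_ys, num_steps=3)
query_label = predict(w, [2.0, 1.0])
```

Because the gradient is rescaled to unit length, each step moves the weight vector by exactly `lr` regardless of how small the raw gradient has become, which is the property that distinguishes this update from ordinary gradient descent on well-separated data.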