In-Context Learning of a Linear Transformer Block: Benefits of the MLP Component and One-Step GD Initialization
This work advances the theoretical understanding of in-context learning in transformers, highlighting the role of the MLP component in reducing approximation error on linear regression tasks; the contribution is incremental but offers foundational insights for model design.
The paper studies in-context learning (ICL) of linear regression with a non-zero-mean Gaussian prior. It shows that a Linear Transformer Block (LTB), which combines linear attention with an MLP component, achieves nearly Bayes-optimal ICL risk, while linear attention alone incurs an irreducible additive approximation error. It also establishes a correspondence between LTB and one-step gradient descent estimators with learnable initialization, indicating that LTB performs ICL by implementing these estimators and that the MLP component is what reduces the approximation error.
We study the \emph{in-context learning} (ICL) ability of a \emph{Linear Transformer Block} (LTB) that combines a linear attention component and a linear multi-layer perceptron (MLP) component. For ICL of linear regression with a Gaussian prior and a \emph{non-zero mean}, we show that LTB can achieve nearly Bayes-optimal ICL risk. In contrast, using only linear attention must incur an irreducible additive approximation error. Furthermore, we establish a correspondence between LTB and one-step gradient descent estimators with learnable initialization ($\mathsf{GD}\text{-}\boldsymbol{\beta}$), in the sense that every $\mathsf{GD}\text{-}\boldsymbol{\beta}$ estimator can be implemented by an LTB estimator and every optimal LTB estimator that minimizes the in-class ICL risk is effectively a $\mathsf{GD}\text{-}\boldsymbol{\beta}$ estimator. Finally, we show that $\mathsf{GD}\text{-}\boldsymbol{\beta}$ estimators can be efficiently optimized with gradient flow, despite a non-convex training objective. Our results reveal that LTB achieves ICL by implementing $\mathsf{GD}\text{-}\boldsymbol{\beta}$, and they highlight the role of MLP layers in reducing approximation error.
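To make the one-step gradient descent estimator with learnable initialization concrete, here is a minimal NumPy sketch. It assumes the estimator takes a single preconditioned gradient step on the in-context least-squares loss, starting from an initialization $\beta_0$, and predicts $\hat y = x_q^\top\big(\beta_0 + \Gamma X^\top (y - X\beta_0)/n\big)$ with a step-size matrix $\Gamma$. The function names `gd_beta_predict` and `icl_risk`, the isotropic Gaussian inputs, the prior $\beta \sim \mathcal{N}(\mu, \tau^2 I)$, and the hand-picked $\Gamma$ are all illustrative assumptions, not the paper's construction. The case $\beta_0 = 0$ stands in for a one-step GD estimator without a learnable initialization.

```python
import numpy as np

# Toy sketch under an assumed data model; not the paper's exact construction.


def gd_beta_predict(X, y, x_query, beta0, Gamma):
    """One-step GD prediction from a learnable initialization beta0 (hypothetical helper).

    Takes one preconditioned step on L(b) = ||y - X b||^2 / (2 n), whose
    gradient at beta0 is -X^T (y - X beta0) / n, then predicts at x_query.
    """
    n = X.shape[0]
    residual = y - X @ beta0                          # context residuals at the initialization
    beta_hat = beta0 + Gamma @ (X.T @ residual) / n   # one GD step with matrix step size Gamma
    return x_query @ beta_hat


def icl_risk(beta0, Gamma, mu, tau, sigma, n, trials=2000, seed=0):
    """Monte Carlo estimate of the ICL squared error for fixed (beta0, Gamma).

    Assumed setup: beta ~ N(mu, tau^2 I), x ~ N(0, I), y = x^T beta + N(0, sigma^2).
    """
    rng = np.random.default_rng(seed)
    d = mu.shape[0]
    errors = np.empty(trials)
    for t in range(trials):
        beta = mu + tau * rng.standard_normal(d)          # task vector from the non-zero-mean prior
        X = rng.standard_normal((n, d))                   # context inputs
        y = X @ beta + sigma * rng.standard_normal(n)     # noisy context labels
        x_q = rng.standard_normal(d)                      # query input
        y_q = x_q @ beta + sigma * rng.standard_normal()  # noisy query label
        errors[t] = (gd_beta_predict(X, y, x_q, beta0, Gamma) - y_q) ** 2
    return errors.mean()


d, n = 5, 10
mu = np.full(d, 3.0 / np.sqrt(d))  # assumed prior mean with norm 3
Gamma = 0.5 * np.eye(d)            # fixed, hand-picked step size, not an optimized one

risk_zero_init = icl_risk(np.zeros(d), Gamma, mu, tau=0.3, sigma=0.1, n=n)
risk_mean_init = icl_risk(mu, Gamma, mu, tau=0.3, sigma=0.1, n=n)
print(f"zero initialization:       {risk_zero_init:.3f}")
print(f"prior-mean initialization: {risk_mean_init:.3f}")
```

Both risk estimates reuse the same seed, so the comparison is paired. With a non-zero prior mean, initializing at the mean should give a markedly smaller risk than starting from zero, which loosely mirrors the point about irreducible error; note, however, that a fixed $\Gamma$ is not the optimized estimator analyzed in the paper.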