How Transformers Utilize Multi-Head Attention in In-Context Learning? A Case Study on Sparse Linear Regression
This work gives machine learning researchers incremental insight into transformer mechanisms, specifically how trained models use multi-head attention during in-context learning.
The study investigates how trained multi-head transformers perform in-context learning on sparse linear regression. It finds that multiple heads are essential in the first layer, which preprocesses the context data, while a single head is usually sufficient in each subsequent layer, which carries out simple optimization steps. The resulting preprocess-then-optimize algorithm outperforms naive gradient descent and ridge regression.
Despite the remarkable success of transformer-based models in various real-world tasks, their underlying mechanisms remain poorly understood. Recent studies have suggested that transformers can implement gradient descent as an in-context learner for linear regression problems and have developed various theoretical analyses accordingly. However, these works mostly focus on the expressive power of transformers by designing specific parameter constructions, and they lack a comprehensive understanding of the inherent working mechanisms after training. In this study, we consider a sparse linear regression problem and investigate how a trained multi-head transformer performs in-context learning. We experimentally discover that the utilization of multiple heads exhibits different patterns across layers: multiple heads are utilized and essential in the first layer, while usually only a single head is sufficient for subsequent layers. We provide a theoretical explanation for this observation: the first layer preprocesses the context data, and the following layers execute simple optimization steps based on the preprocessed context. Moreover, we demonstrate that such a preprocess-then-optimize algorithm can significantly outperform naive gradient descent and ridge regression algorithms. Further experimental results support our explanations. Our findings offer insights into the benefits of multi-head attention and contribute to understanding the more intricate mechanisms hidden within trained transformers.
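
To make the comparison in the abstract concrete, the following is a minimal NumPy sketch of the three kinds of estimators it names, applied to one synthetic sparse linear regression task. It is an illustration under stated assumptions, not the construction analyzed in the paper: the task sizes, the noise level, the step-size rule, and in particular the coordinate-wise rescaling used as the "preprocessing" step are hypothetical choices made here for clarity. All function names (`make_sparse_task`, `gradient_descent`, `ridge`, `preprocess_then_optimize`) are likewise made up for this example.

```python
import numpy as np


def make_sparse_task(n=20, d=16, s=2, noise=0.1, seed=0):
    """Sample one task: y = X @ w + noise, where w has only s nonzero entries."""
    rng = np.random.default_rng(seed)
    w = np.zeros(d)
    support = rng.choice(d, size=s, replace=False)
    w[support] = rng.normal(size=s)
    X = rng.normal(size=(n, d))
    y = X @ w + noise * rng.normal(size=n)
    return X, y, w


def gradient_descent(X, y, steps=10):
    """Naive gradient descent on (1 / 2n) * ||X w - y||^2, starting from w = 0."""
    n, d = X.shape
    # Step size 1/L, where L = sigma_max(X)^2 / n is the smoothness constant.
    lr = n / np.linalg.norm(X, 2) ** 2
    w = np.zeros(d)
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / n
    return w


def ridge(X, y, lam=1.0):
    """Closed-form ridge regression: solve (X^T X + lam * I) w = X^T y."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)


def preprocess_then_optimize(X, y, steps=10):
    """Rescale coordinates first, then run the same gradient descent.

    ASSUMPTION: the rescaling below (absolute empirical correlation of each
    coordinate with y, normalized to at most 1) is a hypothetical stand-in
    for a preprocessing layer, not the paper's mechanism.
    """
    n = X.shape[0]
    r = np.abs(X.T @ y) / n
    r = r / r.max()
    X_tilde = X * r  # scales column j of X by r[j]
    v = gradient_descent(X_tilde, y, steps)
    # Map back to the original coordinates: X_tilde @ v == X @ (r * v).
    return r * v


if __name__ == "__main__":
    X, y, w_true = make_sparse_task()
    for name, w_hat in [
        ("gradient descent", gradient_descent(X, y)),
        ("ridge", ridge(X, y)),
        ("preprocess-then-optimize", preprocess_then_optimize(X, y)),
    ]:
        print(f"{name}: squared parameter error = {np.sum((w_hat - w_true) ** 2):.4f}")
```

Running the script prints the squared parameter error of each estimator on a single random task. Which estimator wins on a given draw depends on the seed, the sparsity level, and the number of steps, so a meaningful comparison would average over many tasks rather than rely on one run.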