One-Layer Transformer Provably Learns One-Nearest Neighbor In Context
This result gives machine learning researchers theoretical insight into transformer mechanisms, though the contribution is incremental because it focuses on a single, classical algorithm.
The paper addresses how transformers perform in-context learning by proving that a single-layer transformer, trained with gradient descent on a nonconvex loss, can learn the one-nearest neighbor prediction rule from prompts.
Transformers have achieved great success in recent years. Interestingly, transformers have shown particularly strong in-context learning capability: even without fine-tuning, they can solve unseen tasks well based purely on task-specific prompts. In this paper, we study the capability of one-layer transformers to learn one of the most classical nonparametric estimators, the one-nearest neighbor prediction rule. Under a theoretical framework where the prompt contains a sequence of labeled training data and unlabeled test data, we show that, although the loss function is nonconvex, a single softmax attention layer trained with gradient descent can successfully learn to behave like a one-nearest neighbor classifier. Our result gives a concrete example of how transformers can be trained to implement nonparametric machine learning algorithms, and sheds light on the role of softmax attention in transformer models.
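To make the connection between softmax attention and one-nearest neighbor concrete, the following minimal sketch compares the two predictors on a toy prompt. It is an illustration under stated assumptions, not the paper's construction: the attention scores are assumed to be the negative squared distance scaled by a hypothetical inverse temperature `beta`, the data points are made up, and no training is performed. As `beta` grows, the softmax weights concentrate on the closest labeled point, so the attention output approaches the one-nearest neighbor label.

```python
import math


def squared_distance(x, z):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((a - b) ** 2 for a, b in zip(x, z))


def one_nn_predict(xs, ys, x_query):
    """Return the label of the training point closest to x_query."""
    dists = [squared_distance(x, x_query) for x in xs]
    return ys[dists.index(min(dists))]


def softmax_attention_predict(xs, ys, x_query, beta):
    """Softmax-weighted average of labels.

    Assumption (illustrative only): the attention score of example i is
    -beta * ||x_i - x_query||^2, where beta is a hypothetical inverse
    temperature rather than a learned parameter.
    """
    scores = [-beta * squared_distance(x, x_query) for x in xs]
    shift = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - shift) for s in scores]
    total = sum(weights)
    return sum(w * y for w, y in zip(weights, ys)) / total


# Toy prompt: three labeled examples and one unlabeled query (made-up data).
xs = [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0)]
ys = [1.0, -1.0, 1.0]
x_query = (0.1, 0.9)

if __name__ == "__main__":
    print("1-NN label:", one_nn_predict(xs, ys, x_query))
    for beta in (0.0, 1.0, 10.0, 50.0):
        pred = softmax_attention_predict(xs, ys, x_query, beta)
        print(f"beta={beta:5.1f}  attention output={pred:+.6f}")
```

With `beta = 0` the attention layer averages all labels uniformly; with a large `beta` its output is numerically indistinguishable from the label of the nearest example. The sketch only illustrates why a softmax layer is expressive enough to represent the rule; whether gradient descent reaches such a solution is the question the paper analyzes.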