LG · AI · Feb 28, 2025

Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought

arXiv:2502.21212v1 · 34 citations · h-index: 5 · ICLR
Originality: Highly original
AI Analysis

This work provides theoretical insight into the mechanisms of CoT training for in-context learning. The contribution is incremental, but it clarifies a specific bottleneck in transformer expressivity.

The paper tackles the problem of understanding how transformers learn to perform multi-step gradient descent with Chain of Thought (CoT) prompting. It shows that, on in-context linear regression, transformers trained with CoT achieve near-exact recovery of the ground-truth weights, while those trained without CoT fail.
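Why a single gradient step falls short can be seen from a standard derivation, sketched here under our own assumptions rather than the paper's exact parameterization: the prompt holds pairs (x_i, y_i), i = 1..n, with noiseless labels y_i = w*^T x_i, the loss is the in-context squared error, the step size η and empirical covariance Σ̂ are our notation, and GD starts from w_0 = 0.

```latex
L(w) = \frac{1}{2n}\sum_{i=1}^{n}\left(w^\top x_i - y_i\right)^2,
\qquad
\hat{\Sigma} = \frac{1}{n}\sum_{i=1}^{n} x_i x_i^\top

w_{k+1} = w_k - \eta \nabla L(w_k) = w_k + \eta\,\hat{\Sigma}\,(w^\ast - w_k)

w_1 = \eta\,\hat{\Sigma}\,w^\ast,
\qquad
w^\ast - w_k = \left(I - \eta\,\hat{\Sigma}\right)^{k} w^\ast
```

One step recovers w* only in the special case where η Σ̂ acts as the identity on w*, which generally fails for a finite random prompt. With Σ̂ invertible and 0 < η < 2 / λ_max(Σ̂), the error factor (I − η Σ̂)^k shrinks toward zero as k grows, which is the role that emitting several intermediate steps plays.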

Chain of Thought (CoT) prompting has been shown to significantly improve the performance of large language models (LLMs), particularly in arithmetic and reasoning tasks, by instructing the model to produce intermediate reasoning steps. Despite the remarkable empirical success of CoT and its theoretical advantages in enhancing expressivity, the mechanisms underlying CoT training remain largely unexplored. In this paper, we study the training dynamics of transformers over a CoT objective on an in-context weight prediction task for linear regression. We prove that while a one-layer linear transformer without CoT can only implement a single step of gradient descent (GD) and fails to recover the ground-truth weight vector, a transformer with CoT prompting can learn to perform multi-step GD autoregressively, achieving near-exact recovery. Furthermore, we show that the trained transformer generalizes effectively to unseen data. With our technique, we also show that looped transformers significantly improve final performance compared to transformers without looping in the in-context learning of linear regression. Empirically, we demonstrate that CoT prompting yields substantial performance improvements.
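The contrast between one GD step and multi-step GD emitted autoregressively can be illustrated with a minimal sketch. It does not train a transformer; it directly simulates gradient descent on one in-context linear regression prompt, treating each intermediate weight vector as a stand-in for one CoT step. The data distribution (Gaussian inputs, noiseless labels), the step size, and all function names (`make_task`, `gd_step`, `cot_trajectory`, `dist`) are illustrative assumptions, not the paper's construction.

```python
import random


def make_task(n, d, seed=0):
    """Hypothetical prompt: x_i ~ N(0, I_d), noiseless labels y_i = <w*, x_i>."""
    rng = random.Random(seed)
    w_star = [rng.gauss(0.0, 1.0) for _ in range(d)]
    xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    ys = [sum(wj * xj for wj, xj in zip(w_star, x)) for x in xs]
    return xs, ys, w_star


def gd_step(w, xs, ys, lr):
    """One GD step on L(w) = (1/2n) * sum_i (<w, x_i> - y_i)^2."""
    n, d = len(xs), len(w)
    grad = [0.0] * d
    for x, y in zip(xs, ys):
        r = sum(wj * xj for wj, xj in zip(w, x)) - y  # residual on example i
        for j in range(d):
            grad[j] += r * x[j] / n
    return [wj - lr * gj for wj, gj in zip(w, grad)]


def cot_trajectory(xs, ys, steps, lr):
    """Emit w_1, ..., w_K in order, each computed from the previous one (like CoT steps)."""
    w = [0.0] * len(xs[0])  # start from w_0 = 0
    traj = []
    for _ in range(steps):
        w = gd_step(w, xs, ys, lr)
        traj.append(w)
    return traj


def dist(a, b):
    """Euclidean distance between two weight vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) ** 0.5


if __name__ == "__main__":
    xs, ys, w_star = make_task(n=40, d=5, seed=0)
    traj = cot_trajectory(xs, ys, steps=200, lr=0.5)
    print("error after 1 step   :", dist(traj[0], w_star))
    print("error after 200 steps:", dist(traj[-1], w_star))
```

In this sketch the first entry of the trajectory plays the role of the single step available without CoT, and its error stays far from zero, while the error after many steps shrinks toward zero. A looped transformer would correspond to applying the same update block repeatedly, which this loop also mimics; that mapping is our reading of the abstract, not the paper's stated construction.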

