LG · CC · ML · Feb 20, 2024

Chain of Thought Empowers Transformers to Solve Inherently Serial Problems

arXiv:2402.12875v4 · 283 citations · h-index: 27 · ICLR
Originality: Highly original
AI Analysis

The paper gives a theoretical explanation for CoT's effectiveness by identifying a key bottleneck in transformer expressiveness for serial computation. The contribution is incremental, but it clarifies why a widely used method works.

The paper asks why chain of thought (CoT) improves transformer accuracy on serial tasks. It shows that, with $T$ steps of CoT, constant-depth transformers with constant-bit precision can solve any problem solvable by boolean circuits of size $T$, whereas without CoT they are limited to $\mathsf{AC}^0$. Empirically, CoT dramatically improves accuracy on tasks such as permutation group composition and iterated squaring, especially for low-depth models.
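To make the "serial" nature of permutation composition concrete, here is a minimal Python sketch. It is an illustration only, not the paper's experimental setup: the function names `compose` and `compose_with_trace`, the tuple encoding of permutations, and the "apply left to right" convention are all assumptions made for this example. The list of running products plays the role of the intermediate steps a CoT would write out.

```python
from typing import List, Tuple

Perm = Tuple[int, ...]  # perm[i] is the image of i (assumed encoding)


def compose(p: Perm, q: Perm) -> Perm:
    """Return the permutation that applies p first, then q."""
    return tuple(q[p[i]] for i in range(len(p)))


def compose_with_trace(perms: List[Perm]) -> List[Perm]:
    """Compose perms left to right, recording every running product.

    Each entry of the returned list depends on the previous one, which is
    what makes the computation sequential; the entries stand in for the
    intermediate tokens of a chain of thought.
    """
    running: Perm = tuple(range(len(perms[0])))  # identity permutation
    trace: List[Perm] = []
    for p in perms:
        running = compose(running, p)
        trace.append(running)
    return trace


if __name__ == "__main__":
    steps = compose_with_trace([(1, 0, 2), (0, 2, 1), (2, 1, 0)])
    for k, step in enumerate(steps, start=1):
        print(f"after {k} permutation(s): {step}")
```

Answering directly corresponds to producing only the last entry of the trace; writing out the whole trace reduces each step to a single composition of two permutations.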

Instructing the model to generate a sequence of intermediate steps, a.k.a. a chain of thought (CoT), is a highly effective method to improve the accuracy of large language models (LLMs) on arithmetic and symbolic reasoning tasks. However, the mechanism behind CoT remains unclear. This work provides a theoretical understanding of the power of CoT for decoder-only transformers through the lens of expressiveness. Conceptually, CoT empowers the model with the ability to perform inherently serial computation, which is otherwise lacking in transformers, especially when depth is low. Given input length $n$, previous works have shown that constant-depth transformers with finite precision and $\mathsf{poly}(n)$ embedding size can only solve problems in $\mathsf{TC}^0$ without CoT. We first show an even tighter expressiveness upper bound for constant-depth transformers with constant-bit precision, which can only solve problems in $\mathsf{AC}^0$, a proper subset of $\mathsf{TC}^0$. However, with $T$ steps of CoT, constant-depth transformers using constant-bit precision and $O(\log n)$ embedding size can solve any problem solvable by boolean circuits of size $T$. Empirically, enabling CoT dramatically improves the accuracy for tasks that are hard for parallel computation, including the composition of permutation groups, iterated squaring, and circuit value problems, especially for low-depth transformers.
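The link between "$T$ steps of CoT" and "circuits of size $T$" can be pictured as evaluating a circuit one gate per step, appending each gate's output to the sequence so later gates can read it. The Python sketch below shows only that intuition; it is not the paper's transformer construction, and the gate encoding, the `eval_circuit_stepwise` name, and the example XOR circuit are assumptions introduced for illustration.

```python
from typing import List, Tuple

# Hypothetical gate encoding: (op, index_a, index_b). Indices refer to
# positions in the growing value list: inputs first, then earlier gates.
# For "NOT", index_b is ignored.
Gate = Tuple[str, int, int]


def eval_circuit_stepwise(inputs: List[int], gates: List[Gate]) -> List[int]:
    """Evaluate gates in order, emitting one output bit per gate.

    The returned list has one entry per gate, so a circuit with T gates
    produces exactly T intermediate values, mirroring T CoT steps.
    """
    values = list(inputs)
    for op, a, b in gates:
        x, y = values[a], values[b]
        if op == "AND":
            out = x & y
        elif op == "OR":
            out = x | y
        elif op == "NOT":
            out = 1 - x
        else:
            raise ValueError(f"unknown gate type: {op}")
        values.append(out)  # later gates may read this value
    return values[len(inputs):]


# XOR(x0, x1) built from four gates: (x0 OR x1) AND NOT (x0 AND x1).
XOR_GATES: List[Gate] = [
    ("OR", 0, 1),   # value index 2
    ("AND", 0, 1),  # value index 3
    ("NOT", 3, 3),  # value index 4
    ("AND", 2, 4),  # value index 5, the circuit output
]

if __name__ == "__main__":
    for x0 in (0, 1):
        for x1 in (0, 1):
            steps = eval_circuit_stepwise([x0, x1], XOR_GATES)
            print(f"x0={x0} x1={x1} steps={steps} output={steps[-1]}")
```

In this picture each step only has to compute a single gate from values already present in the sequence, which is a much simpler job than computing the circuit's final output in one shot.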

