Dissecting the Interplay of Attention Paths in a Statistical Mechanics Theory of Transformers

Harvard
arXiv:2405.15926v2 | 10 citations | h-index: 82 | NIPS
Originality: Highly original
AI Analysis

This work provides a foundational theoretical framework for analyzing attention mechanisms in Transformers. Its advance is incremental, but it offers insights for model optimization, such as attention-head pruning.

The paper tackles the theoretical understanding of Transformers by developing a statistical mechanics theory for a tractable multi-head self-attention network. It shows that the predictor statistics decompose into kernels associated with pairs of attention paths, and that the interplay between these paths enhances generalization, as confirmed by experiments on synthetic and real-world tasks.
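A toy sketch can make the path-kernel decomposition concrete. Everything in it is an illustrative assumption rather than the paper's exact model: attention matrices are softmax scores computed from the bare input with hypothetical fixed query-key matrices `QK`; an attention path is one head index per layer; each path maps a sequence $x$ to the last token's features $\phi_\pi(x)$; the path-pair kernel is $K_{\pi\pi'}(x,x') = \phi_\pi(x)\cdot\phi_{\pi'}(x')/N_0$; and a positive semi-definite matrix `U` stands in for the kernel-combination weights, giving $K = \sum_{\pi,\pi'} U_{\pi\pi'} K_{\pi\pi'}$. The predictor mean then takes the standard kernel-regression form $f(x^*) = k_*^\top (K + T I)^{-1} y$.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
T_TOK, N0, L, H, P = 6, 8, 2, 2, 20  # tokens, input width, layers, heads per layer, examples

# Hypothetical fixed query-key matrices, one per (layer, head); not learned here.
QK = rng.standard_normal((L, H, N0, N0)) / N0

def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)

def attention(x, layer, head):
    # Assumption: attention scores depend only on the bare input x (T_TOK x T_TOK).
    return softmax(x @ QK[layer, head] @ x.T)

# All attention paths: one head per layer, H**L paths in total.
paths = list(itertools.product(range(H), repeat=L))

def path_features(x, path):
    # Propagate x through the chosen head of each layer; read out the last token.
    z = x
    for layer, head in enumerate(path):
        z = attention(x, layer, head) @ z
    return z[-1]  # shape (N0,)

def features(X):
    # Shape (num_sequences, num_paths, N0)
    return np.array([[path_features(x, p) for p in paths] for x in X])

def total_kernel(FA, FB, U):
    # K(x, x') = sum_{pi, pi'} U[pi, pi'] * phi_pi(x) . phi_pi'(x') / N0
    return np.einsum("apn,pq,bqn->ab", FA, U, FB) / N0

X = rng.standard_normal((P, T_TOK, N0))
y = np.sign(rng.standard_normal(P))            # toy binary labels
B = rng.standard_normal((len(paths), len(paths)))
U = B @ B.T / len(paths)                        # placeholder PSD combination weights

F = features(X)
K = total_kernel(F, F, U)
temperature = 0.1                               # plays the role of T in (K + T I)^{-1}
alpha = np.linalg.solve(K + temperature * np.eye(P), y)

x_test = rng.standard_normal((1, T_TOK, N0))
f_mean = total_kernel(features(x_test), F, U) @ alpha  # predictor mean, shape (1,)
```

Here `U` is a random placeholder, so the sketch only shows how path-pair kernels are combined, not how the task shapes the weights. Because the number of paths grows as $H^L$, explicit enumeration like this is practical only for small networks.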

Despite the remarkable empirical performance of Transformers, their theoretical understanding remains elusive. Here, we consider a deep multi-head self-attention network that is closely related to Transformers yet analytically tractable. We develop a statistical mechanics theory of Bayesian learning in this model, deriving exact equations for the network's predictor statistics under the finite-width thermodynamic limit, i.e., $N,P\rightarrow\infty$, $P/N=\mathcal{O}(1)$, where $N$ is the network width and $P$ is the number of training examples. Our theory shows that the predictor statistics are expressed as a sum of independent kernels, each one pairing different 'attention paths', defined as information pathways through different attention heads across layers. The kernels are weighted according to a 'task-relevant kernel combination' mechanism that aligns the total kernel with the task labels. As a consequence, this interplay between attention paths enhances generalization performance. Experiments confirm our findings on both synthetic and real-world sequence classification tasks. Finally, our theory explicitly relates the kernel combination mechanism to properties of the learned weights, allowing for a qualitative transfer of its insights to models trained via gradient descent. As an illustration, we demonstrate an efficient size reduction of the network, by pruning those attention heads that are deemed less relevant by our theory.
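The head-pruning idea can be sketched under a loudly hypothetical relevance score: each head is scored by the summed diagonal weight $U_{\pi\pi}$ of the attention paths that pass through it, and only the highest-scoring heads in each layer are kept. The paper derives its own criterion from the learned weights; this score and the helper names `head_relevance` and `heads_to_keep` are illustrative assumptions, not the authors' method.

```python
import itertools
import numpy as np

def head_relevance(U, num_layers, num_heads):
    """Hypothetical score: total diagonal weight U[pi, pi] of all paths through each head."""
    # Path ordering must match the rows/columns of U (itertools.product order).
    paths = list(itertools.product(range(num_heads), repeat=num_layers))
    scores = np.zeros((num_layers, num_heads))
    for i, path in enumerate(paths):
        for layer, head in enumerate(path):
            scores[layer, head] += U[i, i]
    return scores

def heads_to_keep(scores, keep_per_layer):
    """Indices of the `keep_per_layer` highest-scoring heads in each layer."""
    order = np.argsort(-scores, axis=1)
    return [sorted(row[:keep_per_layer].tolist()) for row in order]

# Toy weights for 2 layers x 3 heads (9 paths); head 2 carries no weight in either layer.
U = np.diag([5.0, 4.0, 0.0, 3.0, 2.0, 0.0, 0.0, 0.0, 0.0])
scores = head_relevance(U, num_layers=2, num_heads=3)
print(heads_to_keep(scores, keep_per_layer=2))  # [[0, 1], [0, 1]]
```

Only the diagonal of `U` is used here for simplicity; a score that also accounts for cross-path terms $U_{\pi\pi'}$ would be an equally plausible variant.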

Code implementations: 1 repo