Transformers as Measure-Theoretic Associative Memory: A Statistical Perspective and Minimax Optimality
This work provides a principled theoretical foundation for designing and analyzing Transformers that handle long, distributional contexts. The contribution is incremental, but it offers new guarantees for machine learning researchers.
The paper tackles the problem of understanding Transformers as associative memory by recasting them in a measure-theoretic framework, showing that a shallow Transformer composed with an MLP can learn recall-and-predict maps at provably minimax-optimal convergence rates.
Transformers excel through content-addressable retrieval and the ability to exploit contexts of, in principle, unbounded length. We recast associative memory at the level of probability measures, treating a context as a distribution over tokens and viewing attention as an integral operator on measures. Concretely, for mixture contexts $ν= I^{-1} \sum_{i=1}^I μ^{(i)}$ and a query $x_{\mathrm{q}}(i^*)$, the task decomposes into (i) recall of the relevant component $μ^{(i^*)}$ and (ii) prediction from $(μ^{(i^*)}, x_{\mathrm{q}})$. We study learned softmax attention (not a frozen kernel) trained by empirical risk minimization and show that a shallow measure-theoretic Transformer composed with an MLP learns the recall-and-predict map under a spectral assumption on the input densities. We further establish a matching minimax lower bound with the same rate exponent (up to multiplicative constants), proving that the convergence order is sharp. The framework offers a principled recipe for designing and analyzing Transformers that recall from arbitrarily long, distributional contexts with provable generalization guarantees.
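To make "attention as an integral operator on measures" concrete, here is a minimal NumPy sketch. It assumes the standard single-head form with score $\exp(\langle W_Q x, W_K y\rangle)$. The matrices `W_Q`, `W_K`, `W_V`, the helpers `measure_attention` and `mixture_context`, and the clustered toy components are hypothetical illustrations, not the paper's construction or its learned parameters. The function evaluates attention against a discrete measure $ν = \sum_j w_j δ_{y_j}$. With uniform weights on the context tokens it reduces to ordinary softmax attention, and because it depends on the context only through $ν$, repeating the whole context leaves the output unchanged. The toy usage hand-sets a query that "addresses" one component of a mixture context, which illustrates the recall step rather than learning it.

```python
import numpy as np

def measure_attention(x_q, atoms, weights, W_Q, W_K, W_V):
    """Softmax attention of a query against the discrete measure
    nu = sum_j weights[j] * delta_{atoms[j]}  (weights sum to 1).

    Returns  integral exp(<W_Q x_q, W_K y>) W_V y dnu(y)
             / integral exp(<W_Q x_q, W_K y>) dnu(y).
    """
    scores = (atoms @ W_K.T) @ (W_Q @ x_q)   # <W_Q x_q, W_K y_j> for every atom y_j
    scores = scores - scores.max()            # stabilise exp; the shift cancels in the ratio
    g = weights * np.exp(scores)              # nu-weighted Gibbs factors
    g = g / g.sum()                           # normalise (the denominator integral)
    return g @ (atoms @ W_V.T)                # integrate W_V y under the reweighted measure

def mixture_context(components):
    """Hypothetical helper: build nu = I^{-1} sum_i mu^(i) from I empirical
    measures, each an array of shape (m_i, d); atom j of component i gets mass 1/(I*m_i)."""
    I = len(components)
    atoms = np.vstack(components)
    weights = np.concatenate([np.full(len(c), 1.0 / (I * len(c))) for c in components])
    return atoms, weights

# Toy recall example (assumed setup): component i clusters around 5 * e_i.
rng = np.random.default_rng(0)
d, I, m = 4, 4, 50
centers = 5.0 * np.eye(d)[:I]
comps = [c + 0.1 * rng.standard_normal((m, d)) for c in centers]
atoms, weights = mixture_context(comps)

i_star = 2
x_q = np.eye(d)[i_star]   # hand-set query that addresses component i*
beta = 10.0               # attention sharpness (assumed, not learned)
out = measure_attention(x_q, atoms, weights, beta * np.eye(d), np.eye(d), np.eye(d))
print(np.round(out, 2))   # should lie near centers[i_star], i.e. roughly 5 * e_2
```

The large `beta` makes the Gibbs reweighting concentrate almost all mass on the atoms of $μ^{(i^*)}$, so the output is approximately a (slightly tilted) mean of that component. In the paper's setting this selection is learned by empirical risk minimization rather than fixed by hand.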