Understanding Factual Recall in Transformers via Associative Memories
This work provides theoretical insight into the mechanisms of factual recall in transformers. The contribution is incremental, but it clarifies a known bottleneck for AI researchers.
The authors study how transformers perform factual recall and show that shallow transformers can achieve near-optimal storage capacity using associative memories. They prove that a transformer with one self-attention layer followed by an MLP reaches 100% accuracy on a synthetic task when either its self-attention or its MLP parameter count scales linearly (up to log factors) with the number of facts.
Large language models have demonstrated an impressive ability to perform factual recall. Prior work has found that transformers trained on factual recall tasks can store information at a rate proportional to their parameter count. In our work, we show that shallow transformers can use a combination of associative memories to obtain such near-optimal storage capacity. We begin by proving that the storage capacities of both linear and MLP associative memories scale linearly with parameter count. We next introduce a synthetic factual recall task, and prove that a transformer with a single layer of self-attention followed by an MLP can obtain 100% accuracy on the task whenever either the total number of self-attention parameters or MLP parameters scales (up to log factors) linearly with the number of facts. In particular, the transformer can trade off between using the value matrices or the MLP as an associative memory to store the dataset of facts. We complement these expressivity results with an analysis of the gradient flow trajectory of a simplified linear attention model trained on our factual recall task, where we show that the model exhibits sequential learning behavior.
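To make the idea of a linear associative memory concrete, here is a minimal sketch using the classical outer-product (Hebbian) construction: each fact (x, y) is written into a d x d matrix as W += u_y e_x^T, with random unit embeddings e_x for inputs and u_y for outputs, and recall returns the output whose embedding best matches W e_x. This textbook construction is a stand-in chosen for illustration; it is not necessarily the construction the paper uses, and it is not tuned to reach the paper's capacity bounds. All function names (`random_unit_vector`, `build_memory`, `recall`, `accuracy`) and the fact mapping are hypothetical.

```python
import math
import random


def random_unit_vector(d: int, rng: random.Random) -> list[float]:
    """Draw a Gaussian vector in R^d and rescale it to unit length."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def build_memory(facts, in_emb, out_emb, d):
    """Outer-product (Hebbian) linear associative memory:
    W = sum over facts (x, y) of out_emb[y] in_emb[x]^T, a d x d matrix."""
    W = [[0.0] * d for _ in range(d)]
    for x, y in facts.items():
        u, e = out_emb[y], in_emb[x]
        for i in range(d):
            ui, row = u[i], W[i]
            for j in range(d):
                row[j] += ui * e[j]
    return W


def recall(W, e_x, out_emb):
    """Return argmax over outputs y of out_emb[y]^T (W e_x)."""
    h = [dot(row, e_x) for row in W]  # h = W e_x
    return max(out_emb, key=lambda y: dot(out_emb[y], h))


def accuracy(num_facts: int, d: int, seed: int = 0) -> float:
    """Store num_facts hypothetical facts in a d x d memory and measure recall."""
    rng = random.Random(seed)
    in_emb = {x: random_unit_vector(d, rng) for x in range(num_facts)}
    out_emb = {y: random_unit_vector(d, rng) for y in range(num_facts)}
    # Hypothetical fact mapping x -> y (any fixed function would do).
    facts = {x: (7 * x + 3) % num_facts for x in range(num_facts)}
    W = build_memory(facts, in_emb, out_emb, d)
    correct = sum(recall(W, in_emb[x], out_emb) == y for x, y in facts.items())
    return correct / num_facts


if __name__ == "__main__":
    print(accuracy(num_facts=16, d=256))
```

The memory's only parameters are the d^2 entries of W; the embeddings are fixed and random. Crosstalk between stored facts grows as more facts are added for a fixed d, so rerunning `accuracy` with many more facts than the dimension can support should show recall degrading. How many facts fit per parameter under better constructions is the subject of the paper's capacity results.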