Residual Matrix Transformers: Scaling the Size of the Residual Stream
This work addresses a bottleneck in transformer scaling, namely that the size of the residual stream is tied to compute and model size, and offers AI researchers and practitioners a more efficient alternative with measured performance gains.
The paper tackles the inefficiency of the residual stream in transformers by replacing it with an outer product memory matrix. The resulting model reaches the same loss as the transformer with 58% fewer FLOPS, 25% fewer parameters, and 41% fewer training tokens, and it outperforms the transformer on downstream evaluations.
The residual stream acts as a memory bus where transformer layers both store and access features (Elhage et al., 2021). We consider changing the mechanism for retrieving and storing information in the residual stream, and replace the residual stream of the transformer with an outer product memory matrix (Kohonen, 1972; Anderson, 1972). We call this model the Residual Matrix Transformer (RMT). We find that the RMT enjoys a number of attractive properties: 1) the size of the residual stream can be scaled independently of compute and model size, improving performance; 2) the RMT can achieve the same loss as the transformer with 58% fewer FLOPS, 25% fewer parameters, and 41% fewer training tokens; and 3) the RMT outperforms the transformer on downstream evaluations. We theoretically analyze the transformer and the RMT, and show that the RMT allows for more efficient scaling of the residual stream, as well as improved variance propagation properties. Code for this project can be found at https://github.com/bmac3/residual-matrix-transformer.
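
To make the idea of an outer product memory concrete, the following is a minimal sketch of the classic associative memory of Kohonen (1972) and Anderson (1972), not the RMT layer itself. Each value vector is stored by adding its outer product with a key vector to a matrix, and it is read back by multiplying the matrix with the same key. The dimensions, the orthonormal keys, and the helper names (outer, add, read) are assumptions chosen for illustration and do not come from the paper or its repository.

```python
# Illustrative outer product (associative) memory in pure Python.
# Assumption: keys are orthonormal, so retrieval is exact.

def outer(v, k):
    """Outer product v k^T as a len(v) x len(k) nested list."""
    return [[vi * kj for kj in k] for vi in v]

def add(A, B):
    """Elementwise sum of two matrices of the same shape."""
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def read(M, k):
    """Retrieve from memory: the matrix-vector product M k."""
    return [sum(m * kj for m, kj in zip(row, k)) for row in M]

# Hypothetical toy data: two value vectors (dimension 2),
# each paired with an orthonormal key (dimension 3).
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
values = [[2.0, -1.0], [0.5, 3.0]]

# Store: M = sum_i v_i k_i^T, a (value dim) x (key dim) matrix.
M = [[0.0] * 3 for _ in range(2)]
for k, v in zip(keys, values):
    M = add(M, outer(v, k))

# Retrieve: M k_j = sum_i v_i (k_i . k_j) = v_j for orthonormal keys.
retrieved = [read(M, k) for k in keys]
```

In this sketch the memory has shape (value dimension) x (key dimension), so its capacity can be increased by lengthening the keys without changing the dimension of the stored values. With keys that are not orthogonal, retrieval would return the stored value plus interference from the other entries.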