Two Heads Are Better than One: Simulating Large Transformers with Small Ones
The paper addresses the difficulty of scaling transformers to long input sequences and offers a theoretical route to handling them more efficiently. The contribution is incremental, since it builds on existing transformer architectures.
The paper tackles the quadratic complexity of self-attention on long input sequences by proving that a large transformer (input length N) can be efficiently simulated by multiple small transformers (input length M ≪ N). In the worst case, O((N/M)^2) small transformers suffice and this bound is tight. In natural scenarios such as average-case inputs, only O(N/M) are needed.
The quadratic complexity of self-attention prevents transformers from scaling effectively to long input sequences. On the other hand, modern GPUs and other specialized hardware accelerators are well-optimized for processing short input sequences in transformers during both training and inference. A natural question arises: can we take advantage of the efficiency of small transformers to deal with long input sequences? In this paper, we show that transformers with long input sequences (large transformers) can be efficiently simulated by transformers that can only take short input sequences (small transformers). Specifically, we prove that any transformer with input length $N$ can be efficiently simulated by only $O((N/M)^2)$ transformers with input length $M \ll N$, and that this cannot be improved in the worst case. However, we then prove that in various natural scenarios, including average-case inputs, sliding-window masking and attention sinks, the optimal number of small transformers, $O(N/M)$, suffices.
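The abstract does not spell out the simulation itself, so the following is only an illustrative sketch of the counting intuition behind the two bounds, not the paper's construction. It relies on a standard fact: exact softmax attention over $N$ tokens can be assembled from partial results computed on pairs of length-$M$ blocks (per-query maximum score, sum of exponentials, and unnormalised output), merged with the log-sum-exp trick as in blockwise or FlashAttention-style kernels. With $B = N/M$ blocks, unrestricted attention needs all $B^2 = (N/M)^2$ block pairs, whereas a sliding window of width $W \le M$ touches at most two key blocks per query block, i.e. $O(N/M)$ pairs. Treating one block pair as a stand-in for one small-transformer call is a simplifying assumption (the paper's small transformers are full transformers, not single attention operations), and the names `block_partial`, `blockwise_attention`, `N`, `M`, `d` and `W` are hypothetical choices for this example.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(Q, K, V, mask):
    """Reference softmax attention over all N tokens; mask(i, j) is True if query i may see key j."""
    out = []
    for i, q in enumerate(Q):
        scores = [dot(q, k) if mask(i, j) else -math.inf for j, k in enumerate(K)]
        top = max(scores)
        w = [math.exp(s - top) for s in scores]  # masked keys get weight exp(-inf) = 0
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


def block_partial(Q, K, V, mask, qs, ks, M):
    """Attention of the length-M query block at qs over the length-M key block at ks.

    For each query returns (max score, sum of exponentials, unnormalised output),
    or None if the mask hides every key in the block from that query.
    """
    res = []
    for i in range(qs, qs + M):
        scores = [(j, dot(Q[i], K[j])) for j in range(ks, ks + M) if mask(i, j)]
        if not scores:
            res.append(None)
            continue
        top = max(s for _, s in scores)
        z, num = 0.0, [0.0] * len(V[0])
        for j, s in scores:
            w = math.exp(s - top)
            z += w
            for c in range(len(num)):
                num[c] += w * V[j][c]
        res.append((top, z, num))
    return res


def blockwise_attention(Q, K, V, mask, M):
    """Exact attention assembled from block pairs; also returns how many pairs were needed."""
    N, d = len(Q), len(V[0])
    B = N // M  # assumes M divides N
    out, pairs_used = [None] * N, 0
    for qb in range(B):
        parts = [[] for _ in range(M)]
        for kb in range(B):
            # Skip block pairs in which the mask hides every key from every query.
            if not any(mask(i, j) for i in range(qb * M, qb * M + M)
                       for j in range(kb * M, kb * M + M)):
                continue
            pairs_used += 1
            for r, p in enumerate(block_partial(Q, K, V, mask, qb * M, kb * M, M)):
                if p is not None:
                    parts[r].append(p)
        # Merge the partial results for each query with the log-sum-exp trick.
        for r, ps in enumerate(parts):
            g = max(top for top, _, _ in ps)
            z = sum(zz * math.exp(top - g) for top, zz, _ in ps)
            out[qb * M + r] = [sum(num[c] * math.exp(top - g) for top, _, num in ps) / z
                               for c in range(d)]
    return out, pairs_used


random.seed(0)
N, M, d, W = 32, 4, 3, 4  # W <= M: the sliding window spans at most two key blocks
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
masks = {
    "full": lambda i, j: True,
    "causal": lambda i, j: j <= i,
    "sliding window": lambda i, j: 0 <= i - j < W,
}
for name, mask in masks.items():
    exact = full_attention(Q, K, V, mask)
    blocked, pairs = blockwise_attention(Q, K, V, mask, M)
    err = max(abs(a - b) for x, y in zip(exact, blocked) for a, b in zip(x, y))
    print(f"{name}: {pairs} block pairs, max error {err:.1e}")
```

With $N = 32$ and $M = 4$ there are $B = 8$ blocks, so the sketch needs 64 block pairs for unrestricted attention ($B^2$), 36 for causal attention ($B(B+1)/2$, still quadratic), and 15 for the sliding window ($2B - 1$, linear in $N/M$). In every case the blockwise result should match the reference attention up to floating-point round-off.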