Slim attention: cut your context memory in half without loss -- K-cache is all you need for MHA
This work addresses memory bottlenecks in deploying large transformer models, though it is an incremental contribution because it builds on existing attention mechanisms.
The paper tackles the high memory usage of transformer models with multi-head attention (MHA) by introducing slim attention, which halves the context memory without any loss of accuracy and can speed up inference by up to 2x for large context windows.
Slim attention shrinks the context memory size by 2x for transformer models with MHA (multi-head attention), which can speed up inference by up to 2x for large context windows. Slim attention is an exact, mathematically identical implementation of the standard attention mechanism and therefore doesn't compromise model accuracy. In other words, slim attention losslessly compresses the context memory by a factor of 2. For encoder-decoder transformers, the context memory can be reduced even further: for the Whisper models, slim attention reduces the context memory by 8x, which can speed up token generation by 5x at batch size 64. For the T5-11B model, the memory can be reduced by 32x because its MHA projection dimension is larger than the embedding dimension. See https://github.com/OpenMachine-ai/transformer-tricks for code and more transformer tricks, and https://www.youtube.com/watch?v=uVtk3B6YO4Y for this paper's YouTube video.
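The title ("K-cache is all you need") hints at why the compression is lossless. In MHA the keys and values are both linear projections of the same layer input X: K = X W_K and V = X W_V. When W_K is square and invertible (heads x head dimension equals the embedding dimension), V = K W_K^-1 W_V = K W_KV, so W_KV can be precomputed once and only K needs to be cached. The sketch below checks this identity on toy matrices with exact rational arithmetic. It is a minimal illustration, not the authors' implementation: it assumes a single head, no biases and no positional encoding, and the matrices and helper names (`matmul`, `inverse`, `cache_reduction`) are hypothetical. The last part reproduces the 2x and 32x ratios from the abstract, assuming the public T5-11B configuration (d_model = 1024, 128 heads of size 128, which is not stated in this text) and a cache that stores one d_model-sized vector per token instead of K and V.

```python
from fractions import Fraction


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def inverse(m):
    """Invert a square matrix by Gauss-Jordan elimination over exact Fractions."""
    n = len(m)
    # Augment [m | I] and reduce the left half to the identity.
    aug = [[Fraction(x) for x in row] + [Fraction(int(i == j)) for j in range(n)]
           for i, row in enumerate(m)]
    for col in range(n):
        pivot = next((r for r in range(col, n) if aug[r][col] != 0), None)
        if pivot is None:
            raise ValueError("W_K is singular; V cannot be recovered from K")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [x / p for x in aug[col]]
        for r in range(n):
            if r != col and aug[r][col] != 0:
                f = aug[r][col]
                aug[r] = [x - f * y for x, y in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


# Hypothetical toy sizes: 3 tokens, embedding dimension d = 2, one head.
X = [[1, 2], [3, -1], [0, 4]]   # layer inputs, one row per token
W_K = [[2, 1], [1, 1]]          # square and invertible (det = 1)
W_V = [[0, 3], [-2, 5]]

K = matmul(X, W_K)              # what the K-cache stores
V = matmul(X, W_V)              # what a standard V-cache would also store

# V = X W_V = K W_K^-1 W_V, so W_KV is computed once per layer.
W_KV = matmul(inverse(W_K), W_V)
V_from_K = matmul(K, W_KV)

# Exact with Fractions; with floats the two would agree only up to rounding.
assert V_from_K == V


def cache_reduction(d_model, proj_dim):
    """Size ratio of a K+V cache (2 * proj_dim values per token) to a cache
    holding a single d_model-sized vector per token."""
    return Fraction(2 * proj_dim, d_model)


print(cache_reduction(1024, 1024))       # MHA with proj_dim == d_model: 2
print(cache_reduction(1024, 128 * 128))  # assumed T5-11B sizes: 32
```

The price of dropping the V-cache is the extra multiplication K W_KV needed to recover V at each decoding step; in long-context inference, where reading the cache dominates, that trade can pay off, which is consistent with the speedups the abstract reports.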