Inference-time sparse attention with asymmetric indexing
This work addresses the cost of self-attention, a computational bottleneck for long-context language models.
The paper speeds up self-attention in transformer models with Saap, an asymmetric indexing technique that typically reduces the fraction of memory looked up by a factor of 20 and saves 60% of the time compared with FlashAttention-v2 on long sequences.
Self-attention in transformer models is an incremental associative memory that maps key vectors to value vectors. One way to speed up self-attention is to employ GPU-compatible vector search algorithms based on standard partitioning methods such as k-means. However, such partitioning methods yield poor results in this context because (1) keys and queries follow different distributions, and (2) the RoPE positional encoding hinders bucket assignment. This paper introduces Saap (Self-Attention with Asymmetric Partitions), which overcomes these problems. It is an asymmetric indexing technique that employs distinct partitions for keys and queries, thereby approximating self-attention with a data-adaptive sparsity pattern. It works on pretrained language models and only requires training a small query classifier offline. On a long-context Llama 3.1 8B model, with sequences ranging from 100k to 500k tokens, Saap typically reduces the fraction of memory that must be looked up by a factor of 20, which translates to a time saving of 60% compared with FlashAttention-v2.
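The asymmetric idea can be illustrated with a toy sketch. It is not Saap's implementation. It assumes keys are bucketed by their nearest k-means centroid (centroids supplied from elsewhere) while queries are routed by a separate linear scorer that stands in for the trained query classifier. Attention is then computed only over keys in the top-scoring buckets. The function names, the linear classifier, and the `top_b` parameter are hypothetical, and RoPE is not modeled.

```python
import math
from typing import List, Sequence, Tuple

Vec = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def assign_keys(keys: List[Vec], key_centroids: List[Vec]) -> List[List[int]]:
    """Key-side partition: each key index goes to its nearest centroid (squared L2)."""
    buckets: List[List[int]] = [[] for _ in key_centroids]
    for idx, k in enumerate(keys):
        best = min(
            range(len(key_centroids)),
            key=lambda c: sum((x - y) ** 2 for x, y in zip(k, key_centroids[c])),
        )
        buckets[best].append(idx)
    return buckets


def route_query(q: Vec, classifier_weights: List[Vec], top_b: int) -> List[int]:
    """Query-side partition (hypothetical): a separate linear scorer rates every
    key bucket, and the query visits the top_b highest-scoring buckets."""
    scores = [dot(w, q) for w in classifier_weights]
    return sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:top_b]


def sparse_attention(
    q: Vec, keys: List[Vec], values: List[Vec], buckets: List[List[int]], selected: List[int]
) -> Tuple[Vec, float]:
    """Softmax attention restricted to keys in the selected buckets.
    Returns the output vector and the fraction of keys that were looked up."""
    idx = [i for c in selected for i in buckets[c]]
    if not idx:
        return [0.0] * len(values[0]), 0.0
    scale = 1.0 / math.sqrt(len(q))
    logits = [dot(q, keys[i]) * scale for i in idx]
    m = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(l - m) for l in logits]
    z = sum(weights)
    out = [0.0] * len(values[0])
    for w, i in zip(weights, idx):
        for d, v in enumerate(values[i]):
            out[d] += w / z * v
    return out, len(idx) / len(keys)


def dense_attention(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Reference: full softmax attention over every key."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [dot(q, k) * scale for k in keys]
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z for d in range(len(values[0]))]
```

Selecting every bucket recovers dense attention exactly. Selecting fewer buckets trades accuracy for a smaller fraction of keys read, which is the quantity the abstract reports reducing. Using a query-side scorer that differs from the key-side assignment is the point of the sketch: a query is not forced into the key bucket nearest to it.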