WildCat: Near-Linear Attention in Theory and Practice
The paper tackles the high computational cost of attention mechanisms in neural networks, which scales quadratically with input length. It introduces WildCat, a method that approximates attention using a weighted coreset selected via randomly pivoted Cholesky, achieving near-linear runtime and, given bounded inputs, super-polynomial error decay. This addresses an efficiency bottleneck in deploying attention-based models for tasks such as image generation and language modeling, and offers a practical improvement over prior approximations.
We introduce WildCat, a high-accuracy, low-cost approach to compressing the attention mechanism in neural networks. While attention is a staple of modern network architectures, it is also notoriously expensive to deploy due to resource requirements that scale quadratically with the input sequence length $n$. WildCat avoids these quadratic costs by only attending over a small weighted coreset. Crucially, we select the coreset using a fast but spectrally accurate subsampling algorithm (randomly pivoted Cholesky) and weight the elements optimally to minimize reconstruction error. Remarkably, given bounded inputs, WildCat approximates exact attention with super-polynomial $O(n^{-\sqrt{\log\log n}})$ error decay while running in near-linear $O(n^{1+o(1)})$ time. In contrast, prior practical approximations either lack error guarantees or require quadratic runtime to guarantee such high fidelity. We couple this advance with a GPU-optimized PyTorch implementation and a suite of benchmark experiments demonstrating the benefits of WildCat for image generation, image classification, and language model KV cache compression.
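The abstract does not spell out WildCat's kernel or its weighting rule, so the following is only a minimal NumPy sketch of the general idea under stated assumptions, not the authors' algorithm or their GPU implementation. It assumes the softmax kernel $k(x, y) = \exp(\langle x, y\rangle/\sqrt{d})$, runs randomly pivoted Cholesky on the key-key kernel matrix to pick a coreset of `rank` keys, and then uses Nyström weights (the least-squares projection of $\sum_j k(\cdot, k_j)\, v_j$ onto the span of the coreset kernel functions) as one natural notion of "optimal" weights. The function names `exp_kernel`, `rp_cholesky`, `coreset_attention`, and `exact_attention` are hypothetical.

```python
import numpy as np


def exp_kernel(x, y, scale):
    """Softmax-attention kernel k(x, y) = exp(scale * <x, y>), a PSD kernel (assumed)."""
    return np.exp(scale * (x @ y.T))


def rp_cholesky(keys, rank, scale, rng):
    """Randomly pivoted Cholesky on the key-key kernel matrix.

    Returns `rank` pivot indices. Only the diagonal and `rank` columns of the
    n x n kernel matrix are formed: O(n * rank * (rank + d)) work.
    """
    n = keys.shape[0]
    diag = np.exp(scale * np.sum(keys * keys, axis=1))  # k(x_i, x_i)
    F = np.zeros((n, rank))  # low-rank factor: kernel matrix ~ F @ F.T
    pivots = []
    for t in range(rank):
        # Sample the next pivot with probability proportional to the residual diagonal.
        p = rng.choice(n, p=diag / diag.sum())
        pivots.append(p)
        # Residual kernel column at the pivot, normalized into the next factor column.
        col = exp_kernel(keys, keys[p:p + 1], scale)[:, 0] - F[:, :t] @ F[p, :t]
        F[:, t] = col / np.sqrt(col[p])
        diag = np.clip(diag - F[:, t] ** 2, 0.0, None)
        diag[p] = 0.0  # never re-select a chosen pivot
    return np.array(pivots)


def coreset_attention(Q, K, V, rank, seed=0):
    """Approximate softmax attention by attending only over a weighted key coreset."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    S = rp_cholesky(K, rank, scale, np.random.default_rng(seed))
    Ks = K[S]
    # Append a ones column so the same weights give both numerator and normalizer.
    V1 = np.hstack([V, np.ones((V.shape[0], 1))])
    # Nystrom weights: project sum_j k(., k_j) v_j onto span{k(., s) : s in S}.
    W = np.linalg.lstsq(exp_kernel(Ks, Ks, scale), exp_kernel(Ks, K, scale) @ V1, rcond=None)[0]
    out = exp_kernel(Q, Ks, scale) @ W  # each query touches only `rank` coreset keys
    return out[:, :-1] / out[:, -1:]


def exact_attention(Q, K, V):
    """Reference quadratic-cost softmax attention, for comparison."""
    A = exp_kernel(Q, K, 1.0 / np.sqrt(Q.shape[1]))
    return (A @ V) / A.sum(axis=1, keepdims=True)
```

For a fixed coreset size $r$, every step above costs time linear in $n$ (the most expensive product is the $r \times n$ kernel block times $V$). The sketch fixes $r$ by hand rather than choosing it as a function of $n$, so it does not reproduce the paper's $O(n^{1+o(1)})$ runtime or error guarantees; it only illustrates why a spectrally accurate coreset plus least-squares weights can track exact attention closely when inputs are bounded.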