RAT+: Train Dense, Infer Sparse -- Recurrence Augmented Attention for Dilated Inference
This work addresses the inference efficiency bottleneck in large language models by enabling sparse inference from a single dense pretraining run, reducing the need to train separate sparse models. The contribution is incremental, in that it improves on existing dilated attention methods.
The paper tackles the severe accuracy degradation that occurs when a pretrained attention model is sparsified to a dilated pattern for inference-time efficiency. It introduces RAT+, a dense-pretraining architecture that augments attention with recurrence to enable flexible sparse inference with minimal adaptation. At 1.5B parameters, RAT+ closely matches dense accuracy at dilation 16 and drops by 2-3 points at dilation 64 on tasks such as commonsense reasoning and LongBench, while outperforming attention under top-k block sparsification.
Structured dilated attention offers an appealing inference-time efficiency knob: it reduces both the attention FLOPs and the KV cache size by a factor of the dilation size D, while preserving long-range connectivity. However, we find that it has a persistent failure mode: sparsifying a pretrained attention model to a dilated pattern leads to severe accuracy degradation. We introduce RAT+, a dense-pretraining architecture that augments attention with full-sequence recurrence and active recurrence learning. A single RAT+ model is pretrained densely once and then flexibly switched at inference time to dilated attention (optionally with local windows) or to hybrid layer/head compositions, requiring only a short 1B-token resolution adaptation rather than retraining separate sparse models. At 1.5B parameters trained on 100B tokens, RAT+ closely matches dense accuracy at D = 16 and drops by about 2-3 points at D = 64 on commonsense reasoning and LongBench tasks, respectively. Moreover, RAT+ outperforms attention when sparsified to top-k block attention. We further scale to 2.6B parameters and 200B tokens and observe the same trend.
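To make the factor-of-D claim concrete, the following minimal sketch counts how many key positions a single query sees under a dilated causal pattern. It is an illustration only, not the paper's implementation. It assumes the pattern is a plain stride (a query at position t attends to t, t - D, t - 2D, ...), optionally combined with a local window of recent positions. The function names and the `local_window` parameter are hypothetical. The sketch models only the attention pattern, not the recurrence that RAT+ adds or how the KV cache is actually stored.

```python
def dilated_causal_keys(query_pos: int, dilation: int, local_window: int = 0) -> list[int]:
    """Key positions visible to `query_pos` under an assumed strided causal pattern."""
    if dilation < 1:
        raise ValueError("dilation must be >= 1")
    # Strided keys: query_pos, query_pos - D, query_pos - 2D, ... down to >= 0.
    strided = set(range(query_pos, -1, -dilation))
    # Optional local window: the `local_window` most recent positions (inclusive).
    local = set(range(max(0, query_pos - local_window + 1), query_pos + 1))
    return sorted(strided | local)


def keys_per_query(seq_len: int, dilation: int) -> int:
    """Number of keys attended to by the last query in a length-`seq_len` sequence."""
    return len(dilated_causal_keys(seq_len - 1, dilation))


if __name__ == "__main__":
    for d in (1, 16, 64):
        print(f"D={d}: {keys_per_query(4096, d)} keys for the last query")
```

For a 4096-token sequence, the last query attends to 4096 keys at D = 1, 256 keys at D = 16, and 64 keys at D = 64, which is the 1/D scaling described in the abstract. Adding a local window increases the count by at most the window size.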