Variational Linear Attention: Stable Associative Memory for Long-Context Transformers
For practitioners of long-context Transformers, VLA offers a stable linear attention variant that mitigates memory interference while maintaining computational efficiency.
Variational Linear Attention (VLA) reframes the linear attention memory update as an online regularised least-squares problem. It reduces the state norm by 109× relative to standard linear attention at T=1000 and achieves near-perfect associative recall within per-head memory limits. A Triton kernel gives a 14× speedup over sequential Python, and its latency falls below that of softmax attention at ~43k tokens.
Linear attention reduces the quadratic cost of softmax attention to $\mathcal{O}(T)$, but its memory state grows as $\mathcal{O}(T)$ in Frobenius norm, causing progressive interference between stored associations. We introduce \textbf{Variational Linear Attention} (VLA), which reframes the memory update as an online regularised least-squares problem with an adaptive penalty matrix maintained via the Sherman-Morrison rank-1 formula. We prove that normalising the write direction to unit length gives a recurrence Jacobian with spectral norm exactly $1$ for all sequence lengths and head dimensions (Proposition 2), and that the state norm is self-limiting under bounded inputs (Proposition 1). Empirically, VLA reduces $\|S_t\|_F$ by $109\times$ relative to standard linear attention at $T{=}1{,}000$. On multi-query associative recall, it achieves near-perfect exact-match accuracy within the effective per-head memory regime ($n_\text{pairs} < d_h$), retains 62\% accuracy at the per-head capacity boundary, and maintains substantially higher retrieval performance than DeltaNet and standard linear attention under increasing memory load. A Triton-fused kernel achieves a $14\times$ speedup over sequential Python and $\mathcal{O}(T)$ scaling, with latency falling below that of softmax attention at approximately 43\,000 tokens.
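To make the "online regularised least-squares" framing concrete, the following is a minimal sketch of a textbook recursive least-squares (RLS) memory whose inverse penalty matrix is maintained with the Sherman-Morrison rank-1 formula. It is an illustration under stated assumptions, not the VLA update itself: the abstract does not specify the adaptive penalty or the unit-norm write direction, so this sketch uses a fixed ridge term `lam`, and every function name here is hypothetical.

```python
# Illustrative sketch only: generic recursive least squares (RLS), NOT the
# paper's exact VLA update. Function names and the fixed ridge `lam` are
# assumptions made for this example. Pure Python, no dependencies.

def dot(x, y):
    return sum(a * b for a, b in zip(x, y))

def matvec(A, x):
    return [dot(row, x) for row in A]

def sherman_morrison_update(A_inv, k):
    """Return (A + k k^T)^{-1} given A_inv = A^{-1}, assuming A is symmetric."""
    Ak = matvec(A_inv, k)            # A^{-1} k
    denom = 1.0 + dot(k, Ak)         # 1 + k^T A^{-1} k
    n = len(k)
    return [[A_inv[i][j] - Ak[i] * Ak[j] / denom for j in range(n)]
            for i in range(n)]

def make_memory(d_k, d_v, lam):
    """Empty memory S (d_v x d_k) and P = (lam * I)^{-1} (d_k x d_k)."""
    S = [[0.0] * d_k for _ in range(d_v)]
    P = [[(1.0 / lam) if i == j else 0.0 for j in range(d_k)]
         for i in range(d_k)]
    return S, P

def rls_write(S, P, k, v):
    """Write the pair (k, v): S <- S + (v - S k) (P_new k)^T."""
    P_new = sherman_morrison_update(P, k)
    gain = matvec(P_new, k)          # write direction P_new k
    pred = matvec(S, k)              # what the memory currently recalls for k
    err = [vi - pi for vi, pi in zip(v, pred)]
    S_new = [[S[i][j] + err[i] * gain[j] for j in range(len(k))]
             for i in range(len(v))]
    return S_new, P_new

# Usage: store two key-value pairs, then query with the first key.
S, P = make_memory(d_k=2, d_v=2, lam=1e-3)
S, P = rls_write(S, P, [1.0, 0.0], [3.0, 4.0])
S, P = rls_write(S, P, [0.0, 1.0], [5.0, 6.0])
recalled = matvec(S, [1.0, 0.0])     # close to [3.0, 4.0], shrunk slightly by lam
```

Each write corrects only the prediction error `v - S k`, so re-writing a pair that is already stored changes the state very little. This differs from the purely additive update of standard linear attention, which accumulates an outer product on every step.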