LG · AI · CL · May 21, 2025

SUS backprop: linear backpropagation algorithm for long inputs in transformers

arXiv:2505.15080v2
Originality: Incremental advance
AI Analysis

The paper addresses the compute bottleneck of training transformers on long sequences by offering a linear-time backpropagation method for attention; the advance is incremental but practically impactful.

The paper tackles the quadratic computational cost of backpropagation through transformer attention on long sequences by proposing a stochastic gradient estimator that cuts backpropagation through most attention weights, reducing the complexity of the attention backward pass from O(n^2) to O(nc). It reports that cutting about 99% of the attention gradient flow for sequences of length ~2000 increases the relative gradient variance by only about 1%.

It is straightforward to design an unbiased gradient estimator that stochastically cuts the backpropagation flow through any part of a computational graph. By cutting the parts that have little effect on the computation, one can, in some situations, save a significant amount of backpropagation computation in exchange for a minimal increase in the stochastic gradient variance. One such situation occurs in the attention mechanism of the transformer architecture. For long sequences, attention becomes the limiting factor, as its compute requirements increase quadratically with sequence length $n$. At the same time, most attention weights become very small, as most attention heads tend to connect a given token with only a small fraction of other tokens in the sequence. These weights become promising targets for cutting backpropagation. We propose a simple probabilistic rule controlled by a single parameter $c$ that cuts backpropagation through most attention weights, leaving at most $c$ interactions per token per attention head. This brings a factor of $c/n$ reduction in the compute required for the attention backpropagation, turning it from quadratic $O(n^2)$ to linear complexity $O(nc)$. We have empirically verified that, for a typical transformer model, cutting about $99\%$ of the attention gradient flow (i.e. choosing $c \sim 25$ to $30$) results in a relative gradient variance increase of only about $1\%$ for $n \sim 2000$, and this increase shrinks as $n$ grows. This approach is amenable to efficient sparse matrix implementation, thus being promising for making the cost of a backward pass negligible relative to the cost of a forward pass when training a transformer model on long sequences.
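The abstract does not spell out the probabilistic rule, so the following is only a minimal NumPy sketch of the general idea of an unbiased, stochastically cut backward pass. It assumes a hypothetical rule in which each attention weight $a_{ij}$ is kept with probability $\min(1, c\,a_{ij})$ and rescaled by the inverse of that probability. Under this assumed rule the number of kept interactions per token is at most $c$ in expectation, which is weaker than the paper's "at most $c$" guarantee. The function names are illustrative, and only the gradient with respect to the value matrix $V$ of `out = A @ V` is shown.

```python
import numpy as np


def softmax_rows(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = scores - scores.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def keep_probabilities(A, c):
    """ASSUMED rule (not taken from the paper): keep a_ij with prob min(1, c * a_ij).

    Each row of A sums to 1, so the expected number of kept entries
    per row is at most c.
    """
    return np.minimum(1.0, c * A)


def sampled_weights(A, c, rng):
    """Randomly cut entries of A and rescale the survivors by 1 / p_ij.

    The rescaling makes the expectation of the result equal to A,
    so any gradient that is linear in A stays unbiased.
    """
    P = keep_probabilities(A, c)
    mask = rng.random(A.shape) < P
    # Divide only where the entry is kept; cut entries stay exactly zero.
    return np.divide(A, P, out=np.zeros_like(A), where=mask)


def grad_V_exact(A, G):
    """Exact gradient of out = A @ V with respect to V, given upstream grad G."""
    return A.T @ G


def grad_V_cut(A, G, c, rng):
    """Unbiased estimate of the same gradient using the cut weights.

    Dense arrays are used for clarity; with roughly n * c nonzeros, a sparse
    matrix product would be the natural way to realise the O(nc) cost.
    """
    return sampled_weights(A, c, rng).T @ G
```

Averaging `grad_V_cut` over many independent draws should approach `grad_V_exact`, which is one way to check unbiasedness numerically. As a sanity check on the quoted figures, with $n = 2000$ and $c = 25$ the kept fraction is $c/n = 1.25\%$, so about $98.75\%$ of the interactions are cut, in line with the "about $99\%$" in the abstract.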

