Dynamics of Stochastic Momentum with Sparse Updates in High Dimensions
For practitioners using momentum in deep learning with sparse gradients (e.g., token embeddings in transformers, which receive a gradient only when their token appears), the paper shows that a single global momentum coefficient can be suboptimal because tokens of different frequencies exhibit conflicting dynamics.
The paper analyzes momentum dynamics under sparse updates, showing that the phase structure depends on the ratio of the momentum retention timescale to the learning timescale, with classical heavy-ball dynamics recovered only when these timescales coincide. It also identifies a spectral conflict: tokens of different frequencies become oscillatory at different momentum values, so no single global momentum suits all of them.
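The dependence of the oscillation onset on sparsity can already be seen in the mean of a toy model. The derivation below is a sketch under our own assumptions (one coordinate with unit curvature, a gradient that arrives independently with probability p per step, and a momentum buffer that decays on every step); it covers only the first moment, not the second-moment dynamics analyzed in the paper. The symbols e_t (error), v_t (buffer), m_t (activity indicator), and \eta (learning rate) are introduced here for illustration.

```latex
% Toy model (assumption): m_t ~ Bernoulli(p), independent of (e_t, v_t)
\begin{aligned}
v_{t+1} &= \beta\, v_t + m_t\, e_t, \qquad e_{t+1} = e_t - \eta\, v_{t+1},\\
\bar e_{t+1} &= (1 + \beta - p\eta)\,\bar e_t - \beta\,\bar e_{t-1}
  \qquad (\bar e_t = \mathbb{E}[e_t]),\\
\text{oscillatory} &\iff (1 - \sqrt{\beta})^2 < p\eta < (1 + \sqrt{\beta})^2,\\
\beta_\star(p) &= \bigl(1 - \sqrt{p\eta}\bigr)^2 \quad \text{(onset of oscillation)}.
\end{aligned}
```

For example, with \eta = 0.01 the onset is \beta_\star = 0.81 for a dense coordinate (p = 1) but 0.9801 for a rare one (p = 0.01), so in this toy model a global \beta = 0.9 is oscillatory for the first and overdamped for the second.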
Existing theory of momentum assumes that gradients arrive at every parameter at a roughly constant rate, an assumption violated in practice by heavy-tailed data distributions and modern architectures. We theoretically analyze the dynamics of two tractable models of momentum under sparse updates: a least squares model with sparse inputs and a logistic regression model with a rare class. Both admit exact closed-form second-moment dynamics whose high-dimensional limits we characterize across three scaling exponents, one each for sparsity, batch size, and momentum decay. In both problems, the phase structure is governed by the ratio of two intrinsic timescales: a momentum retention timescale (how many active updates the buffer survives) and a learning timescale (how many active updates it takes to reduce the squared error). When learning is much slower than retention, the limit matches SGD; when learning is faster, the system is unstable; where the timescales coincide, we recover classical heavy-ball dynamics. Because the oscillatory dynamics occur at different momentum values for tokens of different sparsity, a global momentum faces a spectral conflict across token frequencies.
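As a rough illustration of the timescale picture (not the paper's least squares or logistic models, nor its second-moment analysis), the sketch below simulates a hypothetical single coordinate with loss 0.5 * e**2 whose gradient arrives with probability p per step, under heavy-ball momentum whose buffer decays on every step, active or not (an assumption). The function names `timescale_ratio`, `simulate`, and `sign_changes` are illustrative, and the two timescale formulas are our own assumptions, chosen so that both are counted in active updates.

```python
import numpy as np


def timescale_ratio(p, beta, lr):
    """Illustrative retention/learning ratio, both counted in active updates.

    Hypothetical definitions for one unit-curvature coordinate whose gradient
    arrives with probability p per step:
      retention ~ p / (1 - beta): active updates expected during the buffer's
                  roughly 1 / (1 - beta)-step lifetime (it decays every step);
      learning  ~ (1 - beta) / lr: active updates needed when each gradient
                  ends up moving the parameter by lr / (1 - beta) times itself.
    """
    retention = p / (1.0 - beta)
    learning = (1.0 - beta) / lr
    return retention / learning


def simulate(p, beta, lr, steps=2000, trials=2000, seed=0):
    """Heavy-ball momentum on loss 0.5 * e**2, gradient present w.p. p per step.

    Returns (mean error, mean squared error) over independent runs for
    t = 0..steps, where e = w - w_star starts at 1 in every run.
    """
    rng = np.random.default_rng(seed)
    e = np.ones(trials)    # error in each independent run
    v = np.zeros(trials)   # momentum buffer in each run
    mean_e = np.empty(steps + 1)
    mse = np.empty(steps + 1)
    mean_e[0], mse[0] = e.mean(), np.mean(e ** 2)
    for t in range(steps):
        active = rng.random(trials) < p          # runs that see a gradient now
        v = beta * v + np.where(active, e, 0.0)  # buffer decays even when inactive
        e = e - lr * v                           # heavy-ball parameter step
        mean_e[t + 1], mse[t + 1] = e.mean(), np.mean(e ** 2)
    return mean_e, mse


def sign_changes(x):
    """Count sign flips along a trajectory (a crude oscillation detector)."""
    s = np.sign(x)
    return int(np.sum(s[1:] * s[:-1] < 0))


if __name__ == "__main__":
    # One global momentum value, two token frequencies.
    for p in (1.0, 0.01):
        mean_e, mse = simulate(p, beta=0.9, lr=0.01)
        print(f"p={p}: ratio={timescale_ratio(p, 0.9, 0.01):.3g}, "
              f"sign changes={sign_changes(mean_e)}, final MSE={mse[-1]:.3g}")
```

With `beta=0.9` and `lr=0.01`, the dense coordinate (`p=1`, ratio about 1) oscillates while the rare one (`p=0.01`, ratio about 0.01) decays without changing sign, a toy version of the spectral conflict. Raising `lr` far enough (e.g. `lr=4.0` with `p=1`) makes the error grow, consistent with the unstable regime under these assumptions.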