The Implicit Bias of Steepest Descent with Mini-batch Stochastic Gradient
This work addresses a question of interest to machine-learning researchers and practitioners: when does stochastic optimization align with full-batch behavior? It provides incremental insights into the effects of batch size, momentum, and variance reduction.
The paper studies the implicit bias of mini-batch stochastic steepest descent in multi-class classification. It shows that without momentum, convergence requires large batches and leaves a batch-dependent margin gap; that momentum enables small-batch convergence through a trade-off that slows convergence; and that variance reduction recovers the full-batch bias for any batch size, at a slower rate.
A variety of widely used optimization methods, such as SignSGD and Muon, can be interpreted as instances of steepest descent under different norm-induced geometries. In this work, we study the implicit bias of mini-batch stochastic steepest descent in multi-class classification, characterizing how batch size, momentum, and variance reduction shape the limiting max-margin behavior and convergence rates under general entry-wise and Schatten-$p$ norms. We show that without momentum, convergence occurs only with large batches, yielding a batch-dependent margin gap while retaining the full-batch convergence rate. In contrast, momentum enables small-batch convergence through a batch-momentum trade-off, though it slows convergence. This approach provides fully explicit, dimension-free rates that improve upon prior results. Moreover, we prove that variance reduction can recover the exact full-batch implicit bias for any batch size, albeit at a slower convergence rate. Finally, we investigate batch-size-one steepest descent without momentum and show, via a concrete data example, that it converges to a fundamentally different bias, which exposes a key limitation of purely stochastic updates. Overall, our unified analysis clarifies when stochastic optimization aligns with full-batch behavior and paves the way for deeper exploration of the training behavior of stochastic gradient steepest descent algorithms.
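To make the "steepest descent under a norm" view concrete, here is a minimal NumPy sketch, assuming a linear multi-class classifier trained with mean cross-entropy. `lp_direction` and `schatten_direction` compute the unit-norm steepest direction (the dual map, via Hölder's inequality and, for Schatten norms, von Neumann's trace inequality); with p = ∞ they give the sign update used by SignSGD and the orthogonalized U Vᵀ update used by Muon. `stochastic_steepest_descent` adds optional momentum and variance reduction. All function names, the step-size schedule η_t = η0/√(t+1), the exponential-moving-average form of momentum, and the SVRG-style control variate are hypothetical choices for illustration, not the paper's exact algorithms; practical Muon also approximates U Vᵀ with Newton-Schulz iterations rather than an exact SVD.

```python
from functools import partial

import numpy as np


def lp_direction(G, p):
    """Entry-wise l_p steepest direction: argmax of <G, D> over ||D||_p <= 1, for 1 < p <= inf."""
    if np.isinf(p):
        return np.sign(G)  # l_inf geometry gives the SignSGD-style sign update
    q = p / (p - 1.0)  # dual exponent, 1/p + 1/q = 1
    norm_q = np.sum(np.abs(G) ** q) ** (1.0 / q)
    if norm_q == 0.0:
        return np.zeros_like(G)
    return np.sign(G) * (np.abs(G) / norm_q) ** (q - 1.0)


def schatten_direction(G, p, tol=1e-12):
    """Schatten-p steepest direction: apply the l_p map to the singular values of G."""
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    if np.isinf(p):
        d = (s > tol * s.max()).astype(float)  # spectral norm: U V^T as in Muon (exact SVD here)
    else:
        d = lp_direction(s, p)
    return (U * d) @ Vt  # U diag(d) V^T


def xent_grad(W, X, y):
    """Gradient of the mean cross-entropy of a linear classifier with logits X @ W.T."""
    Z = X @ W.T
    Z -= Z.max(axis=1, keepdims=True)  # shift logits for numerical stability
    P = np.exp(Z)
    P /= P.sum(axis=1, keepdims=True)
    P[np.arange(len(y)), y] -= 1.0  # softmax probabilities minus one-hot labels
    return P.T @ X / len(y)


def normalized_margin(W, X, y, norm):
    """Min over samples of (correct logit - best wrong logit), divided by norm(W)."""
    Z = X @ W.T
    rows = np.arange(len(y))
    correct = Z[rows, y]  # advanced indexing returns a copy
    Z[rows, y] = -np.inf
    return np.min(correct - Z.max(axis=1)) / norm(W)


def stochastic_steepest_descent(X, y, k, direction, steps=500, batch_size=32,
                                beta=0.0, variance_reduction=False, eta0=0.05, seed=0):
    """Mini-batch steepest descent with optional momentum and SVRG-style variance reduction.

    Illustrative assumptions: eta_t = eta0 / sqrt(t + 1), EMA momentum
    M_t = beta * M_{t-1} + (1 - beta) * G_t, and an SVRG control variate whose
    snapshot is refreshed every n // batch_size steps. Requires batch_size <= n.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = np.zeros((k, d))
    M = np.zeros_like(W)  # momentum buffer; beta = 0 means no momentum
    refresh = max(n // batch_size, 1)
    for t in range(steps):
        idx = rng.choice(n, size=batch_size, replace=False)
        G = xent_grad(W, X[idx], y[idx])
        if variance_reduction:
            if t % refresh == 0:  # new snapshot and its full-batch gradient (always at t = 0)
                W_snap, mu = W.copy(), xent_grad(W, X, y)
            G = G - xent_grad(W_snap, X[idx], y[idx]) + mu
        M = beta * M + (1.0 - beta) * G
        # direction(M) is the steepest ascent direction, so subtract it to descend
        W = W - eta0 / np.sqrt(t + 1) * direction(M)
    return W


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y = np.repeat(np.arange(3), 40)  # 3 balanced classes, 120 samples
    X = 5.0 * np.eye(3)[y] + 0.3 * rng.standard_normal((120, 3))  # separable toy data
    sign_dir = partial(lp_direction, p=np.inf)
    runs = [("full batch", dict(batch_size=120)),
            ("batch 8 + momentum", dict(batch_size=8, beta=0.9)),
            ("batch 8 + SVRG", dict(batch_size=8, variance_reduction=True))]
    for name, kwargs in runs:
        W = stochastic_steepest_descent(X, y, 3, sign_dir, **kwargs)
        print(name, normalized_margin(W, X, y, lambda A: np.abs(A).max()))
```

Running the script prints the ℓ∞-normalized margin for a full-batch run, a small-batch run with momentum, and a small-batch run with the SVRG-style estimator on a separable toy dataset. The numbers only illustrate the setting and are not results from the paper; swapping `sign_dir` for `partial(schatten_direction, p=np.inf)` and the margin norm for the spectral norm gives the Muon-style geometry.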