ML · LG · OC · Feb 20, 2024

Tracking the Median of Gradients with a Stochastic Proximal Point Method

arXiv:2402.12828v2 · 2 citations · h-index: 9 · Trans. Mach. Learn. Res.
AI Analysis

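This work addresses robustness in stochastic optimization for settings such as distributed learning with corrupted nodes or heavy-tailed noise. Its contribution is arguably incremental, since it is closely tied to existing methods such as clipping and momentum, which it reinterprets as online estimators of the median and the mean of gradients, respectively.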

The paper tackles the problem of robust gradient estimation in stochastic optimization by proposing a method to track the median of gradients using a stochastic proximal point approach, showing that it can converge under heavy-tailed, state-dependent noise.

There are several applications of stochastic optimization where one can benefit from a robust estimate of the gradient. Examples include distributed learning with corrupted nodes, training data containing large outliers, learning under privacy constraints, and heavy-tailed noise arising from the dynamics of the algorithm itself. Here we study SGD with robust gradient estimators based on estimating the median. We first derive iterative methods based on the stochastic proximal point method for computing the median gradient and generalizations thereof. Then we propose an algorithm that estimates the median gradient across iterations, and find that several well-known methods are special cases of this framework. For instance, we observe that different forms of clipping yield online estimators of the median of gradients, whereas (heavy-ball) momentum corresponds to an online estimator of the mean. Finally, we provide a theoretical framework for an algorithm computing the median gradient across samples, and show that the resulting method can converge even under heavy-tailed, state-dependent noise.
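The contrast between clipping (median) and momentum (mean) can be illustrated with a minimal sketch. This is not the authors' algorithm: the function names, the step-size schedule c / sqrt(t + 1), and the Cauchy noise model are all assumptions chosen for illustration. A stochastic proximal point step on the loss ||z - g|| (whose minimizer over a distribution of samples g is the geometric median) has the closed form m + clip_eta(g - m), i.e. a clipped step toward the new sample. The same step on the squared loss ||z - g||^2 / 2 gives (m + eta * g) / (1 + eta), an exponential moving average of the kind used in momentum buffers.

```python
import numpy as np


def prox_median_step(m, g, eta):
    """Stochastic proximal point step on f(z) = ||z - g||_2.

    Closed-form solution of argmin_z ||z - g|| + ||z - m||^2 / (2 * eta):
    move m toward the sample g by at most eta, i.e. m + clip_eta(g - m).
    """
    d = g - m
    dist = np.linalg.norm(d)
    if dist <= eta:
        return g.copy()  # the step is long enough to land exactly on g
    return m + (eta / dist) * d  # clipped step of length exactly eta


def prox_mean_step(m, g, eta):
    """Stochastic proximal point step on f(z) = ||z - g||_2^2 / 2.

    The solution (m + eta * g) / (1 + eta) is an exponential moving
    average with weight 1 / (1 + eta) on the previous estimate.
    """
    return (m + eta * g) / (1.0 + eta)


def track(step, samples, c=1.0):
    """Run an online estimator over a stream, using step size c / sqrt(t + 1)."""
    m = np.zeros(samples.shape[1])
    for t, g in enumerate(samples):
        m = step(m, g, c / np.sqrt(t + 1))
    return m


# Hypothetical stream: a fixed "true" gradient plus heavy-tailed Cauchy noise.
# The noise has no mean, but its distribution is centrally symmetric, so the
# geometric median of the samples is true_grad.
rng = np.random.default_rng(0)
true_grad = np.array([1.0, -2.0])
samples = true_grad + rng.standard_cauchy(size=(20_000, 2))

m_median = track(prox_median_step, samples)
m_mean = track(prox_mean_step, samples)
print("median tracker error:", np.linalg.norm(m_median - true_grad))
print("mean tracker error:  ", np.linalg.norm(m_mean - true_grad))
```

Because each step of the median tracker has length at most eta, a single huge sample cannot move the estimate far. The moving average, by contrast, mixes every sample in linearly; a nonnegatively weighted average of i.i.d. Cauchy variables is again Cauchy, so its error does not shrink as the stream grows.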
