Rennala MVR: Improved Time Complexity for Parallel Stochastic Optimization via Momentum-Based Variance Reduction
For practitioners training large models on heterogeneous clusters, this work improves the time complexity of parallel stochastic optimization, though the gains are incremental over existing methods.
The paper proposes Rennala MVR, a variance-reduced extension of Rennala SGD, and shows that, under mean-squared smoothness, variance reduction can improve time complexity in heterogeneous parallel environments in relevant parameter regimes. Experiments on stochastic quadratics and neural networks demonstrate empirical gains over Rennala SGD.
Large-scale machine learning models are trained on clusters of machines that exhibit heterogeneous performance due to hardware variability, network delays, and system-level instabilities. In such environments, time complexity rather than iteration complexity becomes the relevant performance metric for optimization algorithms. Recent work by Tyurin and Richtárik (2023) established the first time complexity analysis for parallel first-order stochastic optimization, proposing Rennala SGD as a time-optimal method for smooth nonconvex optimization. However, Rennala SGD is fundamentally a modification of SGD, and variance reduction techniques are known to improve the iteration complexity of SGD. In this work, we investigate whether variance reduction can also improve time complexity in heterogeneous systems. We show that, under a mean-squared smoothness assumption, variance reduction can improve time complexity in relevant parameter regimes. To this end, we propose Rennala MVR, a variance-reduced extension of Rennala SGD based on momentum-based variance reduction, and analyze its oracle and time complexity. We establish lower bounds for time complexity under these assumptions. On a stochastic quadratic benchmark, experiments with the exact method support the theory, while neural-network experiments with a practical inexact variant show similar empirical gains over Rennala SGD.
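The abstract combines two ingredients: Rennala-style batch collection, where heterogeneous workers race to fill a batch at the current point, and momentum-based variance reduction (MVR). The sketch below illustrates both under loudly stated assumptions. It is not the paper's algorithm, nor its exact or inexact variant. Here, each worker i needs a fixed `taus[i]` time units per gradient and restarts at the beginning of every round. The objective is a hypothetical toy quadratic. The estimator is the standard STORM-style recursion g_t = ∇f(x_t; ξ) + (1 - a)(g_{t-1} - ∇f(x_{t-1}; ξ)), with each sample ξ evaluated at both iterates. All function names (`batch_collection_time`, `sample_grad_pair`, `rennala_mvr_sketch`) are illustrative.

```python
import heapq
import random


def batch_collection_time(taus, batch_size):
    """Time until workers jointly finish `batch_size` stochastic gradients.

    Assumption: worker i always needs exactly taus[i] time units per gradient,
    and every worker starts a fresh computation at time 0 of each round.
    """
    # Min-heap of (next completion time, worker index).
    heap = [(tau, i) for i, tau in enumerate(taus)]
    heapq.heapify(heap)
    finished = 0
    while True:
        t, i = heapq.heappop(heap)
        finished += 1
        if finished >= batch_size:
            return t
        # Worker i immediately starts its next gradient.
        heapq.heappush(heap, (t + taus[i], i))


def sample_grad_pair(x, x_prev, rng):
    """One stochastic gradient of the toy quadratic f(x) = 0.5 * ||x||^2,
    evaluated at x and x_prev with the SAME random draw (required by MVR)."""
    c = rng.uniform(0.5, 1.5)                 # random curvature, mean 1
    noise = [rng.gauss(0.0, 0.1) for _ in x]  # additive noise, mean 0
    g_new = [c * xi + ni for xi, ni in zip(x, noise)]
    g_old = [c * xi + ni for xi, ni in zip(x_prev, noise)]
    return g_new, g_old


def rennala_mvr_sketch(x0, taus, batch_size=4, steps=200, lr=0.1, a=0.1, seed=0):
    """Hypothetical MVR loop whose wall-clock cost uses Rennala-style rounds."""
    rng = random.Random(seed)
    # Fixed worker times make every round equally long.
    round_time = batch_collection_time(taus, batch_size)
    dim = len(x0)
    x, x_prev, g = list(x0), list(x0), None
    elapsed = 0.0
    for _ in range(steps):
        pairs = [sample_grad_pair(x, x_prev, rng) for _ in range(batch_size)]
        new_avg = [sum(p[0][k] for p in pairs) / batch_size for k in range(dim)]
        old_avg = [sum(p[1][k] for p in pairs) / batch_size for k in range(dim)]
        if g is None:
            g = new_avg  # first round: plain minibatch gradient
        else:
            # g_t = grad(x_t) + (1 - a) * (g_{t-1} - grad(x_{t-1})), same samples
            g = [gn + (1 - a) * (gp - go) for gn, gp, go in zip(new_avg, g, old_avg)]
        elapsed += round_time
        x_prev, x = x, [xi - lr * gi for xi, gi in zip(x, g)]
    return x, elapsed
```

Consider `taus=[1.0, 2.0, 10.0]` with a batch size of 4. Gradients complete at times 1, 2, 2 and 3, so each round lasts 3.0 time units. The slowest worker never contributes. This is why wall-clock time, not iteration count, is the relevant metric in heterogeneous systems.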