GRAWA: Gradient-based Weighted Averaging for Distributed Training of Deep Learning Models
This work addresses the challenge of efficient distributed training for deep learning practitioners, offering incremental improvements in communication efficiency and optimization performance.
The authors address distributed training of deep learning models under time constraints by proposing GRAWA, a gradient-based weighted averaging algorithm that prioritizes flat regions of the optimization landscape. Compared to state-of-the-art baselines, it converges faster, recovers better-quality local optima, and communicates less frequently.
We study distributed training of deep learning models in time-constrained environments. We propose a new algorithm that periodically pulls workers towards a center variable computed as a weighted average of the workers, with weights inversely proportional to the workers' gradient norms so that recovering flat regions of the optimization landscape is prioritized. We develop two asynchronous variants of the proposed algorithm, Model-level and Layer-level Gradient-based Weighted Averaging (MGRAWA and LGRAWA, respectively), which differ in whether the weighting is computed over the entire model or applied layer-wise. On the theoretical front, we prove convergence guarantees for the proposed approach in both convex and non-convex settings. We then experimentally demonstrate that our algorithms outperform competing methods by achieving faster convergence and recovering better-quality, flatter local optima. We also carry out an ablation study to analyze the scalability of the proposed algorithms in more crowded distributed training environments. Finally, we report that our approach requires less frequent communication and fewer distributed updates than the state-of-the-art baselines.
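The weighting scheme described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function names, the `eps` guard against zero gradient norms, the `strength` pull coefficient, and the representation of a model as a list of flat per-layer lists are all assumptions made for illustration. The sketch is synchronous and omits the asynchronous communication of MGRAWA and LGRAWA. It shows only how weights inversely proportional to gradient norms yield a center variable, computed either over the whole model or layer by layer.

```python
import math

def l2_norm(vec):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in vec))

def inverse_norm_weights(grad_norms, eps=1e-12):
    """Weights inversely proportional to gradient norms, normalized to sum to 1.
    eps is an assumed guard against division by zero."""
    inv = [1.0 / (g + eps) for g in grad_norms]
    total = sum(inv)
    return [v / total for v in inv]

def weighted_average(vectors, weights):
    """Element-wise weighted average of equally sized flat vectors."""
    dim = len(vectors[0])
    return [sum(w * v[k] for w, v in zip(weights, vectors)) for k in range(dim)]

def model_level_center(worker_layers, worker_grads):
    """One weight per worker, from the gradient norm of its entire model."""
    norms = [l2_norm([g for layer in grads for g in layer]) for grads in worker_grads]
    weights = inverse_norm_weights(norms)
    n_layers = len(worker_layers[0])
    return [weighted_average([w[l] for w in worker_layers], weights)
            for l in range(n_layers)]

def layer_level_center(worker_layers, worker_grads):
    """One weight per worker per layer, from that layer's gradient norm."""
    center = []
    for l in range(len(worker_layers[0])):
        norms = [l2_norm(grads[l]) for grads in worker_grads]
        weights = inverse_norm_weights(norms)
        center.append(weighted_average([w[l] for w in worker_layers], weights))
    return center

def pull_towards(layers, center, strength):
    """Move a worker's parameters a fraction `strength` of the way to the center
    (the pull coefficient is an assumption of this sketch)."""
    return [[(1.0 - strength) * p + strength * c for p, c in zip(pl, cl)]
            for pl, cl in zip(layers, center)]

# Two hypothetical workers, each with two one-parameter layers.
worker_layers = [[[0.0], [0.0]], [[4.0], [4.0]]]
worker_grads = [[[3.0], [4.0]], [[4.0], [3.0]]]

m_center = model_level_center(worker_layers, worker_grads)
l_center = layer_level_center(worker_layers, worker_grads)
pulled = pull_towards(worker_layers[0], m_center, strength=0.5)
```

In this toy setup both workers have the same whole-model gradient norm, so the model-level center is the plain average, while the layer-level center leans in each layer towards the worker whose gradient for that layer is smaller.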