When, Where and Why to Average Weights?
This research addresses two goals: improving the generalization of machine learning models and reducing their training time. Both matter to the broader machine learning community, particularly to practitioners training deep learning models.
The authors investigated the effectiveness of averaging checkpoints along the training trajectory. They found that it significantly accelerates training, yielding considerable efficiency gains at minimal implementation and memory cost, and that it mildly improves generalization. The evaluation, which spans seven architectures and datasets, also examined whether averaging can replace learning rate decay and showed that the best performance is achieved when averaging is optimally combined with learning rate annealing.
Averaging checkpoints along the training trajectory is a simple yet powerful approach to improve the generalization performance of machine learning models and reduce training time. Motivated by these potential gains, and in an effort to fairly and thoroughly benchmark this technique, we present an extensive evaluation of averaging techniques in modern deep learning, which we perform using AlgoPerf \citep{dahl_benchmarking_2023}, a large-scale benchmark for optimization algorithms. We investigate whether weight averaging can reduce training time, improve generalization, and replace learning rate decay, as suggested by recent literature. Our evaluation across seven architectures and datasets reveals that averaging significantly accelerates training and yields considerable efficiency gains at minimal implementation and memory cost, while mildly improving generalization across all considered workloads. Finally, we explore the relationship between averaging and learning rate annealing and show how to optimally combine the two to achieve the best performance.
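
To make "averaging checkpoints along the training trajectory" concrete, the following is a minimal sketch of two common averaging rules: a uniform running average and an exponential moving average (EMA). The abstract does not say which schemes the paper evaluates, so both rules, the class name `RunningWeightAverage`, the function `ema_update`, and the `decay` parameter are illustrative assumptions rather than the authors' implementation. Parameters are modeled as plain dictionaries of float lists so the example runs without any deep learning framework.

```python
from typing import Dict, List

# Hypothetical parameter container: name -> flat list of weights.
Params = Dict[str, List[float]]


class RunningWeightAverage:
    """Uniform running average of weight snapshots (illustrative helper).

    Only one extra copy of the weights is stored, regardless of how many
    snapshots have been folded into the average.
    """

    def __init__(self) -> None:
        self.count = 0
        self.avg: Params = {}

    def update(self, params: Params) -> None:
        """Fold the current weights into the average."""
        self.count += 1
        if self.count == 1:
            # First snapshot: the average is simply a copy of the weights.
            self.avg = {name: list(values) for name, values in params.items()}
            return
        for name, values in params.items():
            avg = self.avg[name]
            for i, w in enumerate(values):
                # Incremental mean: avg_n = avg_{n-1} + (w - avg_{n-1}) / n
                avg[i] += (w - avg[i]) / self.count


def ema_update(avg: Params, params: Params, decay: float = 0.999) -> None:
    """In-place exponential moving average: avg = decay*avg + (1-decay)*w."""
    for name, values in params.items():
        a = avg[name]
        for i, w in enumerate(values):
            a[i] = decay * a[i] + (1.0 - decay) * w


# Usage: pretend these are weights sampled at three points during training.
averager = RunningWeightAverage()
for snapshot in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    averager.update({"w": snapshot})
print(averager.avg["w"])
```

In a real training loop, `update` would be called every few optimizer steps on the live model weights, and the averaged copy (not the live weights) would be used for evaluation. How often to sample and when to start averaging are choices this sketch leaves open.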