LG · May 26, 2022

Trainable Weight Averaging: Accelerating Training and Improving Generalization

arXiv:2205.13104v4 · 2 citations · h-index: 14 · Has Code
Originality: Incremental advance
AI Analysis

This work addresses the need for more flexible and efficient weight averaging techniques in deep learning, offering incremental improvements over stochastic weight averaging for practitioners training large-scale models.

The paper tackles suboptimal weight averaging in deep neural networks by introducing Trainable Weight Averaging (TWA), which learns the weighting coefficients instead of fixing them in advance. Compared with existing methods, TWA reduces training time by over 40% on CIFAR and 30% on ImageNet while maintaining comparable performance, and it improves generalization.
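
The core idea, learning the averaging coefficients rather than fixing them, can be illustrated with a minimal NumPy sketch. It assumes a toy least-squares model, stacks the candidate weights as the columns of a matrix `W`, and runs gradient descent on the coefficient vector `alpha` only, starting from equal weights (the plain SWA average). The helper names `loss_and_grad` and `trainable_weight_average`, the synthetic "checkpoints", and the unconstrained coefficients are illustrative assumptions, not the authors' implementation (see https://github.com/nblt/TWA for that).

```python
import numpy as np

def loss_and_grad(w, X, y):
    """Mean squared error 0.5 * mean((X @ w - y)**2) and its gradient w.r.t. w."""
    r = X @ w - y
    return 0.5 * np.mean(r ** 2), X.T @ r / len(y)

def trainable_weight_average(W, X, y, steps=500):
    """Learn coefficients alpha so that w = W @ alpha has low loss on (X, y).

    W has shape (d, k); each column is one candidate weight vector (a saved
    checkpoint). Only the k coefficients are trained, so the solution stays
    in the subspace spanned by the candidates. (Hypothetical helper.)
    """
    k = W.shape[1]
    alpha = np.full(k, 1.0 / k)               # equal weights: the plain SWA average
    H = X.T @ X / len(y)                      # Hessian of the loss w.r.t. w
    lr = 1.0 / np.linalg.eigvalsh(W.T @ H @ W).max()  # 1/L step, loss never increases
    for _ in range(steps):
        _, g_w = loss_and_grad(W @ alpha, X, y)
        alpha -= lr * (W.T @ g_w)             # chain rule: dL/dalpha = W^T dL/dw
    return alpha

# Synthetic stand-in for checkpoints: noisy copies of a target weight vector,
# with a different noise level per candidate.
rng = np.random.default_rng(0)
d, k = 20, 5
w_true = rng.normal(size=d)
X_train, X_val = rng.normal(size=(200, d)), rng.normal(size=(100, d))
y_train = X_train @ w_true + 0.1 * rng.normal(size=200)
y_val = X_val @ w_true + 0.1 * rng.normal(size=100)
W = w_true[:, None] + rng.normal(size=(d, k)) * np.array([0.2, 0.5, 1.0, 1.5, 2.0])

alpha_t = trainable_weight_average(W, X_train, y_train)  # TWA-t: training data
alpha_v = trainable_weight_average(W, X_val, y_val)      # TWA-v: validation data
swa_val = loss_and_grad(W.mean(axis=1), X_val, y_val)[0]
twa_val = loss_and_grad(W @ alpha_v, X_val, y_val)[0]
print(f"SWA val loss {swa_val:.4f}  TWA-v val loss {twa_val:.4f}")
```

Because SWA corresponds to the fixed point `alpha = 1/k`, any descent step on `alpha` can only match or improve on SWA for the data the coefficients are fitted to; fitting on a held-out split mirrors the TWA-v variant, while fitting on training data mirrors TWA-t.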

Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). Existing approaches such as stochastic weight averaging (SWA) rely on pre-set weighting schemes, which can be suboptimal when the candidate weights are diverse. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns the weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression of the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization through weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
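
The distributed framework with a low-bit projection matrix can be pictured with a small single-process simulation. This sketch assumes the parameter dimension is split row-wise across workers, that each worker stores its rows of an orthonormal basis `P` as int8 with one float32 scale per column, and that summing the workers' partial products stands in for an all-reduce. The helpers `quantize_int8`, `dequantize`, and `sharded_projected_grad` and the per-column symmetric quantization scheme are hypothetical choices for illustration; the paper's actual compression and communication scheme may differ.

```python
import numpy as np

def quantize_int8(P):
    """Symmetric per-column int8 quantization: P is approximately q * scale."""
    scale = np.abs(P).max(axis=0) / 127.0
    scale[scale == 0] = 1.0              # avoid division by zero for all-zero columns
    q = np.clip(np.round(P / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize(q, scale):
    return q.astype(np.float32) * scale

def sharded_projected_grad(shards, grad_shards):
    """Project a gradient onto span(P) with P split row-wise across workers.

    shards[i] = (q_i, scale_i): worker i's int8 rows of P.
    grad_shards[i]: worker i's slice of the full gradient g.
    Returns the per-worker slices of P @ (P^T g).
    """
    # Each worker forms its partial k-vector P_i^T g_i; summing them plays the
    # role of an all-reduce across workers.
    c = sum(dequantize(q, s).T @ g for (q, s), g in zip(shards, grad_shards))
    # Every worker then maps the shared k-vector back to its own rows.
    return [dequantize(q, s) @ c for (q, s) in shards]

rng = np.random.default_rng(0)
d, k, n_workers = 1200, 8, 4
P, _ = np.linalg.qr(rng.normal(size=(d, k)))   # orthonormal basis of a k-dim subspace
g = rng.normal(size=d)                          # full gradient w.r.t. the weights
rows = np.array_split(np.arange(d), n_workers)
shards = [quantize_int8(P[r]) for r in rows]
approx = np.concatenate(sharded_projected_grad(shards, [g[r] for r in rows]))
exact = P @ (P.T @ g)
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
print(f"relative error of int8 projection: {rel_err:.4f}")
```

Storing the basis as int8 takes a quarter of the memory of float32, and only a length-`k` vector has to be exchanged between workers, which is why splitting the large dimension `d` while keeping the small dimension `k` shared is a natural design under these assumptions.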

Code Implementations: 1 repo
