LG · ML · Feb 6, 2021

Weight Rescaling: Effective and Robust Regularization for Deep Neural Networks with Batch Normalization

arXiv:2102.03497v2 · 2 citations
AI Analysis

This work addresses the practical challenges of regularization for deep neural networks with batch normalization, providing a more robust and less hyperparameter-sensitive alternative to weight decay for practitioners.

This paper identifies two issues with weight decay in deep neural networks with batch normalization (BN-DNNs). First, with non-adaptive optimizers the effective learning rate keeps increasing, which leads to overfitting. Second, across optimizers, generalization is sensitive to the weight decay hyperparameter. The authors propose Weight Rescaling (WRS) as an alternative: it explicitly rescales the weights to unit norm. They show that WRS is effective and robust on a range of computer vision tasks compared with weight decay and other methods.
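The "effective learning rate" mentioned above can be made concrete with the standard analysis of scale-invariant weights. This is a common derivation for BN-DNNs and not necessarily the paper's own notation. Consider a weight vector $\mathbf{w}$ followed by BN, trained with plain SGD at learning rate $\eta$. Scale invariance makes the gradient orthogonal to $\mathbf{w}$ and inversely proportional to $\lVert\mathbf{w}\rVert_2$. As a result, the step taken by the weight's direction scales with $\eta / \lVert\mathbf{w}\rVert_2^2$, to first order in $\eta$:

```latex
\begin{aligned}
&\mathcal{L}(\alpha\mathbf{w}) = \mathcal{L}(\mathbf{w})\ (\alpha > 0)
\;\Rightarrow\;
\nabla\mathcal{L}(\alpha\mathbf{w}) = \tfrac{1}{\alpha}\,\nabla\mathcal{L}(\mathbf{w}),
\qquad
\mathbf{w}^\top \nabla\mathcal{L}(\mathbf{w}) = 0 \\
&\mathbf{w}_{t+1} = \mathbf{w}_t - \eta\,\nabla\mathcal{L}(\mathbf{w}_t)
\;\Rightarrow\;
\lVert\mathbf{w}_{t+1}\rVert_2^2 = \lVert\mathbf{w}_t\rVert_2^2 + \eta^2\,\lVert\nabla\mathcal{L}(\mathbf{w}_t)\rVert_2^2 \\
&\hat{\mathbf{w}} = \mathbf{w}/\lVert\mathbf{w}\rVert_2,
\qquad
\hat{\mathbf{w}}_{t+1} - \hat{\mathbf{w}}_t \approx -\frac{\eta}{\lVert\mathbf{w}_t\rVert_2^2}\,\nabla\mathcal{L}(\hat{\mathbf{w}}_t)
\;\Longrightarrow\;
\eta_{\text{eff}} = \frac{\eta}{\lVert\mathbf{w}_t\rVert_2^2}
\end{aligned}
```

Without regularization, plain SGD can only grow $\lVert\mathbf{w}\rVert_2$, which shrinks $\eta_{\text{eff}}$. Weight decay shrinks the norm instead. Once the gradients become small, that shrinkage dominates, and $\eta_{\text{eff}}$ keeps rising. Holding $\lVert\mathbf{w}_t\rVert_2 = 1$ pins $\eta_{\text{eff}}$ at $\eta$.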

Weight decay is often used to ensure good generalization when training deep neural networks with batch normalization (BN-DNNs), where some convolution layers are invariant to weight rescaling due to the normalization. In this paper, we demonstrate that the practical use of weight decay still has unsolved problems, despite existing theoretical work explaining the effect of weight decay in BN-DNNs. On the one hand, when an optimizer with a non-adaptive learning rate, e.g., SGD with momentum (SGDM), is used, the effective learning rate continues to increase even after the initial training stage, which leads to overfitting in many neural architectures. On the other hand, with both SGDM and adaptive learning rate optimizers, e.g., Adam, the effect of weight decay on generalization is quite sensitive to its hyperparameter, so finding an optimal weight decay parameter requires an extensive hyperparameter search. To address these weaknesses, we propose to regularize the weight norm using a simple yet effective weight rescaling (WRS) scheme as an alternative to weight decay. WRS controls the weight norm by explicitly rescaling it to unit norm, which not only prevents a large increase in the gradient but also ensures a sufficiently large effective learning rate to improve generalization. On a variety of computer vision applications, including image classification, object detection, semantic segmentation, and crowd counting, we show the effectiveness and robustness of WRS compared with weight decay, implicit weight rescaling (weight standardization), and gradient projection (AdamP).
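Below is a minimal, self-contained sketch of the contrast described in the abstract. It uses a toy scale-invariant loss in place of a real BN network. It compares SGD with weight decay against an SGD step followed by rescaling the weight to unit norm, and it tracks the effective learning rate `lr / ||w||^2`. Several details are illustrative assumptions and do not come from the paper's implementation:

- the function names (`toy_grad`, `sgd_wrs_step`, `run`, etc.);
- rescaling a single weight vector after every step;
- the hyperparameters.

```python
import math
from typing import List, Tuple

Vector = List[float]


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def effective_lr(lr: float, w: Vector) -> float:
    """Effective step size of a scale-invariant weight: lr / ||w||^2."""
    return lr / l2_norm(w) ** 2


def toy_grad(w: Vector, target: Vector) -> Vector:
    """Gradient of the toy loss L(w) = -<target, w / ||w||> (hypothetical stand-in).

    Like a filter followed by BN, L(alpha * w) == L(w), so the gradient is
    orthogonal to w and scales as 1 / ||w||.
    """
    n = l2_norm(w)
    u = [wi / n for wi in w]
    dot = sum(ti * ui for ti, ui in zip(target, u))
    # Tangential component of -target at u, divided by ||w||.
    return [-(ti - dot * ui) / n for ti, ui in zip(target, u)]


def sgd_weight_decay_step(w: Vector, g: Vector, lr: float, wd: float) -> Vector:
    """Plain SGD with L2 weight decay: w <- w - lr * (g + wd * w)."""
    return [wi - lr * (gi + wd * wi) for wi, gi in zip(w, g)]


def sgd_wrs_step(w: Vector, g: Vector, lr: float, eps: float = 1e-12) -> Vector:
    """Hypothetical WRS sketch: SGD step, then rescale back to unit L2 norm.

    Assumption: rescaling happens after every step; the paper's exact
    schedule and per-layer/per-filter granularity may differ.
    """
    updated = [wi - lr * gi for wi, gi in zip(w, g)]
    norm = max(l2_norm(updated), eps)
    return [wi / norm for wi in updated]


def run(steps: int = 1000, lr: float = 0.1, wd: float = 0.01) -> Tuple[Vector, Vector]:
    """Train the same unit-norm starting weight with weight decay and with WRS."""
    target = [1.0, 0.0, 0.0]
    w_wd = [0.6, 0.8, 0.0]  # unit norm, shared starting point
    w_wrs = list(w_wd)
    for _ in range(steps):
        w_wd = sgd_weight_decay_step(w_wd, toy_grad(w_wd, target), lr, wd)
        w_wrs = sgd_wrs_step(w_wrs, toy_grad(w_wrs, target), lr)
    return w_wd, w_wrs


if __name__ == "__main__":
    lr = 0.1
    w_wd, w_wrs = run(lr=lr)
    print(f"weight decay: ||w|| = {l2_norm(w_wd):.3f}, eff. lr = {effective_lr(lr, w_wd):.3f}")
    print(f"WRS:          ||w|| = {l2_norm(w_wrs):.3f}, eff. lr = {effective_lr(lr, w_wrs):.3f}")
```

In this toy setting the loss converges early. After that, the weight-decay run keeps shrinking `||w||`, so its effective learning rate keeps climbing above `lr`. The WRS run stays at `||w|| = 1`, so its effective learning rate equals `lr` throughout.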
