The Crucial Role of Normalization in Sharpness-Aware Minimization
This work provides insights into a key component of SAM, an optimizer that improves deep neural network performance, but its contribution is incremental because it builds on existing theoretical explanations.
The paper investigates the role of normalization in Sharpness-Aware Minimization (SAM) and finds that normalization stabilizes the algorithm and enables it to drift along manifolds of minima, which contributes to better performance and to robustness against the choice of hyper-parameters.
Sharpness-Aware Minimization (SAM) is a recently proposed gradient-based optimizer (Foret et al., ICLR 2021) that greatly improves the prediction performance of deep neural networks. Consequently, there has been a surge of interest in explaining its empirical success. We focus, in particular, on understanding the role played by normalization, a key component of the SAM updates. We theoretically and empirically study the effect of normalization in SAM for both convex and non-convex functions, revealing two key roles played by normalization: (i) it helps stabilize the algorithm; and (ii) it enables the algorithm to drift along a continuum (manifold) of minima, a property that recent theoretical work identifies as key to better performance. We further argue that these two properties of normalization make SAM robust to the choice of hyper-parameters, supporting the practicality of SAM. Our conclusions are backed by various experiments.
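To make the stabilization claim concrete, here is a minimal Python sketch built on our own assumptions; it is not the authors' code or one of their experiments. SAM (Foret et al.) perturbs the weights along the normalized gradient before descending, w <- w - η∇f(w + ρ∇f(w)/‖∇f(w)‖), while the unnormalized variant (sometimes called USAM) uses w + ρ∇f(w) instead. On the toy objective f(w) = ½λw², USAM reduces to w <- (1 - ηλ(1 + ρλ))w and diverges once ηλ(1 + ρλ) > 2. SAM's perturbation, by contrast, always has length ρ, so its iterates stay bounded whenever ηλ < 2. The helper names `sam_step` and `run`, the toy objective, and the values λ = 1, η = 1.5, ρ = 0.5 are illustrative choices.

```python
import numpy as np


def sam_step(w, grad_fn, lr, rho, normalize=True, eps=1e-12):
    """One SAM step (normalize=True) or one unnormalized USAM step (normalize=False)."""
    g = grad_fn(w)
    if normalize:
        # SAM: the ascent perturbation always has length (about) rho,
        # no matter how small or large the gradient is.
        perturb = rho * g / (np.linalg.norm(g) + eps)
    else:
        # USAM: the perturbation scales with the gradient magnitude.
        perturb = rho * g
    # Descend with the gradient evaluated at the perturbed point.
    return w - lr * grad_fn(w + perturb)


def run(normalize, lam=1.0, lr=1.5, rho=0.5, steps=100):
    """Iterate on the toy quadratic f(w) = 0.5 * lam * ||w||^2 and return the final |w|."""

    def grad_fn(w):
        return lam * w  # gradient of the toy quadratic

    w = np.array([1.0])
    for _ in range(steps):
        w = sam_step(w, grad_fn, lr, rho, normalize)
    return float(np.linalg.norm(w))


print("SAM  final |w|:", run(normalize=True))   # stays bounded (about 1.5)
print("USAM final |w|:", run(normalize=False))  # blows up (about 1.25**100)
```

With these settings USAM's per-step factor is 1 - 1.5 × 1.5 = -1.25, so |w| grows by 25% per step. The SAM iterates do not converge to the minimizer either: with a fixed ρ they settle into an oscillation of amplitude ηλρ/(2 - ηλ) = 1.5 around w = 0, but they remain bounded.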