mSAM: Micro-Batch-Averaged Sharpness-Aware Minimization
This work targets practitioners concerned with generalization in deep learning. It offers an incremental improvement over SAM that can be implemented flexibly.
The paper tackles generalization in over-parameterized deep learning models by studying mSAM, a variant of Sharpness-Aware Minimization (SAM) that aggregates updates derived from adversarial perturbations computed on separate micro-batches. mSAM reaches flatter minima and generalizes better than SAM across a range of image classification and NLP tasks.
Modern deep learning models are over-parameterized, and different optima can have widely varying generalization performance. The Sharpness-Aware Minimization (SAM) technique modifies the underlying loss function to steer gradient descent methods toward flatter minima, which are believed to generalize better. Our study focuses on a variant of SAM known as micro-batch SAM (mSAM), which aggregates updates derived from adversarial perturbations computed on multiple shards (micro-batches) of a mini-batch during training. We extend a recently developed and well-studied general framework for flatness analysis to show theoretically that SAM achieves flatter minima than SGD, and that mSAM achieves even flatter minima than SAM. We substantiate this theoretical result with a thorough empirical evaluation on a variety of image classification and natural language processing tasks. We also show that, contrary to previous work, mSAM can be implemented in a flexible and parallelizable manner without significantly increasing computational cost. Our implementation of mSAM generalizes better than SAM across a wide range of tasks, further supporting our theoretical framework.
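The micro-batch aggregation described above can be illustrated with a small sketch. It assumes SAM's usual first-order approximation: for a loss L at weights w, the adversarial perturbation is ε = ρ ∇L(w) / ‖∇L(w)‖, and the update uses the gradient evaluated at w + ε. The sketch applies this to a toy NumPy linear regression with mean-squared-error loss. Each mini-batch is split into `num_micro` micro-batches, a separate perturbation and perturbed gradient are computed on each, and the results are averaged before the SGD step. The model, loss, data, function names, and hyperparameters (`rho`, `num_micro`, learning rate) are all hypothetical illustrations, not the authors' implementation.

```python
import numpy as np


def mse_loss(w, X, y):
    """Mean-squared-error loss 0.5 * mean((X @ w - y)^2) for a linear model."""
    return 0.5 * np.mean((X @ w - y) ** 2)


def mse_grad(w, X, y):
    """Gradient of mse_loss with respect to the weights w."""
    return X.T @ (X @ w - y) / len(y)


def sam_grad(w, X, y, rho):
    """SAM gradient on one batch: move w along the normalized gradient
    (first-order worst case within an L2 ball of radius rho), then return
    the gradient at the perturbed point. As in SAM's approximation, the
    dependence of eps on w is ignored when differentiating."""
    g = mse_grad(w, X, y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # small constant avoids division by zero
    return mse_grad(w + eps, X, y)


def msam_grad(w, X, y, rho, num_micro):
    """mSAM gradient (hypothetical sketch): split the mini-batch into
    num_micro micro-batches, compute a separate SAM perturbation and
    perturbed gradient on each, and average the per-micro-batch gradients."""
    X_parts = np.array_split(X, num_micro)
    y_parts = np.array_split(y, num_micro)
    grads = [sam_grad(w, Xj, yj, rho) for Xj, yj in zip(X_parts, y_parts)]
    return np.mean(grads, axis=0)


# Toy data: noisy linear regression (illustrative only).
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 5))
w_true = rng.normal(size=5)
y = X @ w_true + 0.1 * rng.normal(size=256)

# Plain SGD loop that uses the mSAM gradient for each mini-batch.
w = np.zeros(5)
lr, rho, batch_size, num_micro = 0.1, 0.05, 64, 4
for epoch in range(50):
    perm = rng.permutation(len(y))
    for start in range(0, len(y), batch_size):
        idx = perm[start:start + batch_size]
        w -= lr * msam_grad(w, X[idx], y[idx], rho, num_micro)

print("final training loss:", mse_loss(w, X, y))
```

With `num_micro=1` the sketch reduces to the plain SAM gradient. The per-micro-batch `sam_grad` calls do not depend on one another, so in a real training setup they could be placed on separate devices. This is one plausible reading of the parallelizability claim, not a description of the paper's implementation.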