LG | Aug 15, 2024

Enhancing Sharpness-Aware Minimization by Learning Perturbation Radius

arXiv:2408.08222v1 | 2 citations | h-index: 11
Originality: Incremental advance
AI Analysis

This work addresses a specific bottleneck in SAM for machine learning practitioners, offering an incremental improvement by automating radius selection.

The paper tackles the challenge of selecting an appropriate perturbation radius in Sharpness-Aware Minimization (SAM), a choice that strongly affects model generalization. It proposes LETS, a bilevel optimization framework that learns this radius, and reports that the method is effective across various architectures and benchmark datasets in computer vision and natural language processing.

Sharpness-aware minimization (SAM) aims to improve model generalization by searching for flat minima in the loss landscape. Each SAM update consists of two steps: one computes the perturbation, and the other computes the update gradient. The choice of the perturbation radius in these steps is crucial to the performance of SAM, but finding an appropriate radius is challenging. In this paper, we propose a bilevel optimization framework called LEarning the perTurbation radiuS (LETS) to learn the perturbation radius for sharpness-aware minimization algorithms. Specifically, in the proposed LETS method, the upper-level problem seeks a good perturbation radius by minimizing the squared generalization gap between the training and validation losses, while the lower-level problem is the SAM optimization problem. Moreover, the LETS method can be combined with any variant of SAM. Experimental results on various architectures and benchmark datasets in computer vision and natural language processing demonstrate the effectiveness of the proposed LETS method in improving the performance of SAM.
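
The two-step SAM update and the upper-level objective described in the abstract can be sketched on a toy problem. This is a minimal illustration, not the paper's implementation: the function names (`perturbation`, `sam_step`, `squared_gap`), the quadratic loss, and the use of the normalized-gradient perturbation from the standard SAM formulation are all assumptions made for the example. The hypergradient step that LETS would use to update the radius is not shown.

```python
import numpy as np

def perturbation(grad, rho):
    """Step 1 (assumed standard SAM form): scale the gradient to norm rho."""
    return rho * grad / (np.linalg.norm(grad) + 1e-12)

def sam_step(w, grad_fn, rho, lr):
    """One SAM update: perturb the weights, then descend using the
    gradient evaluated at the perturbed point (step 2)."""
    eps = perturbation(grad_fn(w), rho)
    return w - lr * grad_fn(w + eps)

def squared_gap(train_loss, val_loss):
    """Upper-level objective sketched from the abstract: the squared
    difference between validation and training losses."""
    return (val_loss - train_loss) ** 2

# Hypothetical toy setup: quadratic loss 0.5 * w^T A w with gradient A w.
A = np.diag([1.0, 10.0])
grad_fn = lambda w: A @ w

w = np.array([1.0, 1.0])
for _ in range(20):
    w = sam_step(w, grad_fn, rho=0.05, lr=0.05)
```

With `rho=0.0` the perturbation vanishes and `sam_step` reduces to plain gradient descent, which is one way to see why the radius controls how much the update differs from standard training.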

