MLLG · Oct 19, 2022

Rethinking Sharpness-Aware Minimization as Variational Inference

arXiv:2210.10452v1 · 9 citations · h-index: 6
Originality: Incremental advance
AI Analysis

This work draws a theoretical link between a sharpness-aware optimizer and variational inference, with the goal of improving generalization in machine learning. It appears incremental because it builds on existing methods.

The paper connects Sharpness-Aware Minimization (SAM) and Mean-Field Variational Inference (MFVI). It shows that both optimize a notion of flatness and that both compute gradients at perturbed parameters. This observation leads to new variational algorithms that combine or interpolate between the two, which are evaluated on benchmark datasets against SAM variants.
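For reference, the standard textbook forms of the two objectives are sketched below. The notation ($L$ for the training loss, $\rho$ for the SAM radius, $\mu, \sigma$ for the mean-field Gaussian, $p$ for the prior) is assumed here and may not match the paper's exact formulation. In both cases the gradient that is actually computed is the loss gradient at a perturbed point: an adversarial one for SAM and a random one for MFVI.

```latex
% SAM: min-max objective and its first-order inner solution
\min_{w} \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon),
\qquad
\hat{\epsilon}(w) = \rho \, \frac{\nabla L(w)}{\|\nabla L(w)\|_2},
\qquad
g_{\mathrm{SAM}} = \nabla L(w)\big|_{w + \hat{\epsilon}(w)}

% MFVI: q(w) = N(mu, diag(sigma^2)), expected loss plus KL to the prior
\min_{\mu, \sigma} \; \mathbb{E}_{q(w)}\!\left[L(w)\right] + \mathrm{KL}\!\left(q(w) \,\|\, p(w)\right)

% Reparametrisation trick: w = mu + sigma \odot xi, xi ~ N(0, I)
\nabla_{\mu} \, \mathbb{E}_{\xi}\!\left[L(\mu + \sigma \odot \xi)\right]
= \mathbb{E}_{\xi}\!\left[\nabla L(w)\big|_{w = \mu + \sigma \odot \xi}\right]
```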

Sharpness-aware minimization (SAM) aims to improve the generalisation of gradient-based learning by seeking out flat minima. In this work, we establish connections between SAM and Mean-Field Variational Inference (MFVI) of neural network parameters. We show that both these methods have interpretations as optimizing notions of flatness, and when using the reparametrisation trick, they both boil down to calculating the gradient at a perturbed version of the current mean parameter. This thinking motivates our study of algorithms that combine or interpolate between SAM and MFVI. We evaluate the proposed variational algorithms on several benchmark datasets, and compare their performance to variants of SAM. Taking a broader perspective, our work suggests that SAM-like updates can be used as a drop-in replacement for the reparametrisation trick.
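The shared structure described in the abstract, "calculating the gradient at a perturbed version of the current mean parameter", can be made concrete with a minimal NumPy sketch. Everything here is a hypothetical toy: a 2-D quadratic loss with an analytic gradient, and the helper names `sam_perturbation`, `mfvi_perturbation` and `perturbed_gradient_step`. The MFVI step is a single-sample Monte Carlo estimate that deliberately omits the KL term and any update to `sigma`. It is not the paper's algorithm.

```python
import numpy as np

# Hypothetical toy loss: anisotropic quadratic L(w) = 0.5 * w^T A w.
A = np.diag([10.0, 1.0])

def loss(w):
    return 0.5 * w @ A @ w

def grad(w):
    # Analytic gradient of the quadratic toy loss.
    return A @ w

def sam_perturbation(mu, rho=0.05):
    # SAM: first-order solution of max_{||eps|| <= rho} L(mu + eps),
    # i.e. a step of length rho along the normalised gradient.
    g = grad(mu)
    return rho * g / (np.linalg.norm(g) + 1e-12)

def mfvi_perturbation(sigma, rng):
    # MFVI with the reparametrisation trick: w = mu + sigma * xi,
    # xi ~ N(0, I). Single sample; KL term and sigma update omitted.
    return sigma * rng.standard_normal(sigma.shape)

def perturbed_gradient_step(mu, eps, lr=0.1):
    # Shared structure: evaluate the gradient at the perturbed point
    # mu + eps, then apply it to the (mean) parameter mu.
    return mu - lr * grad(mu + eps)

rng = np.random.default_rng(0)
mu = np.array([1.0, 1.0])
sigma = np.full(2, 0.05)

# Only the perturbation rule differs between the two updates.
mu_sam = perturbed_gradient_step(mu, sam_perturbation(mu))
mu_vi = perturbed_gradient_step(mu, mfvi_perturbation(sigma, rng))
print("SAM step:", mu_sam, "loss", loss(mu_sam))
print("MFVI step:", mu_vi, "loss", loss(mu_vi))
```

In this toy sketch the two updates differ only in which perturbation function is passed to `perturbed_gradient_step`. That is the mechanical sense in which a SAM-like perturbation could stand in for the random reparametrisation sample. The paper's variational algorithms may combine the two in a different way.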

Code Implementations: 1 repo