LG · May 25, 2023

Sharpness-Aware Minimization Leads to Low-Rank Features

arXiv:2305.16292v2 · 44 citations · Has Code
Originality: Incremental advance
AI Analysis

The paper offers mechanistic insight into how SAM affects feature learning. The advance is incremental, but it is useful for understanding generalization in deep learning.

The paper investigates how Sharpness-Aware Minimization (SAM) reduces the rank of features in neural networks across a range of architectures and tasks. It shows that SAM entirely prunes a significant number of activations, and that this pruning contributes to the rank reduction.
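The summary above depends on two measurable quantities: feature rank and the number of pruned (never-active) activations. The sketch below shows one way to compute both from a matrix of post-ReLU features. It assumes that feature rank is the smallest number of principal components explaining 99% of the feature variance. The helper names `feature_rank` and `count_dead_units` and the `threshold` and `tol` parameters are hypothetical and do not come from the authors' repository.

```python
import numpy as np


def feature_rank(features: np.ndarray, threshold: float = 0.99) -> int:
    """Smallest number of principal components explaining `threshold` of the variance.

    Assumed definition of feature rank (illustrative, not the authors' code).
    features: array of shape (n_samples, n_features), e.g. post-ReLU activations.
    """
    # PCA works on centered data, so subtract the per-feature mean.
    centered = features - features.mean(axis=0, keepdims=True)
    # Squared singular values of the centered matrix are proportional to PCA variances.
    s = np.linalg.svd(centered, compute_uv=False)
    var = s ** 2
    total = var.sum()
    if total == 0.0:
        # All features are constant, so no component carries any variance.
        return 0
    explained = np.cumsum(var) / total
    # First index at which the cumulative explained variance reaches the threshold.
    return int(np.searchsorted(explained, threshold) + 1)


def count_dead_units(features: np.ndarray, tol: float = 0.0) -> int:
    """Number of units (columns) that are inactive (<= tol) on every sample.

    Hypothetical helper: a unit counts as pruned only with respect to this batch.
    """
    return int(np.sum(np.all(features <= tol, axis=0)))
```

A unit counts as pruned here only if it is inactive on every probed sample, so the count depends on the data used to probe the network. A pruned unit gives a constant zero column, which adds no variance, so it can only lower the feature rank.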

Sharpness-aware minimization (SAM) is a recently proposed method that minimizes the sharpness of the training loss of a neural network. While its generalization improvement is well known and is the primary motivation, we uncover an additional intriguing effect of SAM: a reduction of the feature rank, which happens at different layers of a neural network. We show that this low-rank effect occurs very broadly: for different architectures such as fully-connected networks, convolutional networks, and vision transformers, and for different objectives such as regression, classification, and language-image contrastive training. To better understand this phenomenon, we provide a mechanistic understanding of how low-rank features arise in a simple two-layer network. We observe that a significant number of activations get entirely pruned by SAM, which directly contributes to the rank reduction. We confirm this effect theoretically and check that it can also occur in deep networks, although the overall rank reduction mechanism can be more complex, especially for deep networks with pre-activation skip connections and self-attention layers. We make our code available at https://github.com/tml-epfl/sam-low-rank-features.
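For readers unfamiliar with SAM, the sketch below shows one step of the standard SAM update on a generic loss. SAM approximately solves min over w of the maximum of L(w + e) for ||e|| <= rho. Each step first moves to an approximate worst-case point inside an L2 ball of radius `rho`, then descends using the gradient taken at that point. It assumes the caller supplies a gradient function. The names `sam_step`, `grad_fn`, `lr`, and `rho` are illustrative, and this is not the authors' training code.

```python
from typing import Callable

import numpy as np


def sam_step(
    w: np.ndarray,
    grad_fn: Callable[[np.ndarray], np.ndarray],
    lr: float = 0.1,
    rho: float = 0.05,
    eps: float = 1e-12,
) -> np.ndarray:
    """One SAM update: ascent to a nearby worst-case point, then a descent step.

    Illustrative sketch; grad_fn(w) must return the loss gradient at w.
    """
    g = grad_fn(w)
    # First-order approximation of the worst-case perturbation in an L2 ball of
    # radius rho; eps avoids division by zero when the gradient vanishes.
    e = rho * g / (np.linalg.norm(g) + eps)
    # The gradient at the perturbed weights drives the actual update of w.
    g_sam = grad_fn(w + e)
    return w - lr * g_sam
```

With `rho=0` the step reduces to plain gradient descent, which makes a quick sanity check. In practice `grad_fn` would be a minibatch gradient, so each SAM step costs two gradient evaluations.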
