ML · AI · LG · Feb 7, 2024

Feature learning as alignment: a structural property of gradient descent in non-linear neural networks

arXiv:2402.05271v4 · 6 citations · h-index: 11 · Trans. Mach. Learn. Res.
AI Analysis

This work offers theoretical insight into how neural networks learn features, a fundamental open problem in supervised learning, and is aimed at researchers in machine learning theory.

The authors set out to explain the neural feature ansatz (NFA), the observed correlation between neural feature matrices and average gradient outer products during training. They show that this correlation is equivalent to alignment between the left singular structure of the weight matrices and the pre-activation tangent features, and that the alignment is driven by the interaction between SGD-induced weight updates and the pre-activation features. They also introduce an optimization rule that increases NFA correlations and improves the quality of the learned features.
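To make the NFA concrete, here is a minimal NumPy sketch under stated assumptions. It uses a hypothetical two-layer ReLU network f(x) = a^T relu(Wx) with scalar output and random, untrained weights. It also assumes the correlation is the uncentered, normalized Frobenius inner product. The sketch computes the first-layer neural feature matrix W^T W and the AGOP of f with respect to the input, then reports their correlation. The function names and the toy network are illustrative only, not the authors' code.

```python
import numpy as np


def nfa_correlation(A, B):
    """Uncentered correlation <A, B>_F / (||A||_F ||B||_F) (assumed definition)."""
    return float(np.sum(A * B) / (np.linalg.norm(A) * np.linalg.norm(B)))


def first_layer_nfm_agop(W, a, X):
    """Layer-1 NFM and AGOP for the hypothetical toy network f(x) = a^T relu(W x).

    W: (k, d) first-layer weights, a: (k,) readout weights, X: (n, d) inputs.
    Returns (nfm, agop, K), where K is the averaged outer product of df/dz
    and z = W x are the pre-activations (K is illustrative notation).
    """
    n = X.shape[0]
    Z = X @ W.T                  # pre-activations, shape (n, k)
    G = (Z > 0) * a              # df/dz = a * relu'(z) per sample, shape (n, k)
    J = G @ W                    # df/dx = W^T (df/dz) per sample, shape (n, d)
    agop = J.T @ J / n           # average gradient outer product, shape (d, d)
    nfm = W.T @ W                # neural feature matrix (weight gram), shape (d, d)
    K = G.T @ G / n              # second moment of pre-activation gradients, shape (k, k)
    return nfm, agop, K


rng = np.random.default_rng(0)
d, k, n = 10, 32, 500
W = rng.standard_normal((k, d)) / np.sqrt(d)
a = rng.standard_normal(k) / np.sqrt(k)
X = rng.standard_normal((n, d))

nfm, agop, K = first_layer_nfm_agop(W, a, X)
print("NFA correlation at initialization:", nfa_correlation(nfm, agop))

# The AGOP factors through the pre-activation gradients: AGOP = W^T K W.
assert np.allclose(agop, W.T @ K @ W)
```

No training happens here, so the printed value says nothing about the paper's results. To watch the correlation evolve, one would recompute it at checkpoints during SGD. The final assertion checks the factorization AGOP = W^T K W, which ties the AGOP to gradients taken at the pre-activations.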

Understanding the mechanisms by which neural networks extract statistics from input-label pairs through feature learning is one of the most important unsolved problems in supervised learning. Prior works demonstrated that the gram matrices of the weights (the neural feature matrices, NFM) and the average gradient outer products (AGOP) become correlated during training, in a statement known as the neural feature ansatz (NFA). Through the NFA, the authors of those works introduced mapping with the AGOP as a general mechanism for neural feature learning. However, these works do not provide a theoretical explanation for this correlation or its origins. In this work, we further clarify the nature of this correlation and explain its emergence. We show that this correlation is equivalent to alignment between the left singular structure of the weight matrices and the newly defined pre-activation tangent features at each layer. We further establish that the alignment is driven by the interaction of weight changes induced by SGD with the pre-activation features, and we analyze the resulting dynamics analytically at early times in terms of simple statistics of the inputs and labels. We prove that the derivative alignment occurs almost surely in specific high-dimensional settings. Finally, we introduce a simple optimization rule, motivated by our analysis of the centered correlation, which dramatically increases the NFA correlations at any given layer and improves the quality of the learned features.
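The claimed equivalence between the NFA correlation and alignment with the left singular structure of W can be illustrated on a one-hidden-layer toy model. The derivation below is a sketch in our own notation, not the paper's definitions. The toy network, the pre-activation gradient g, its second moment K, and ρ as the uncentered Frobenius correlation are all assumptions. Here K serves as a stand-in for the pre-activation tangent features of the first layer, and W = USV^T is a thin SVD.

```latex
\begin{aligned}
f(x) &= a^\top \phi(Wx), \qquad W \in \mathbb{R}^{k \times d}, \quad z = Wx, \quad
g(x) = \frac{\partial f}{\partial z} = a \odot \phi'(Wx),\\
\nabla_x f(x) &= W^\top g(x),\\
\mathrm{NFM} &= W^\top W, \qquad
\mathrm{AGOP} = \frac{1}{n}\sum_{i=1}^{n} \nabla_x f(x_i)\,\nabla_x f(x_i)^\top = W^\top K W,
\qquad K = \frac{1}{n}\sum_{i=1}^{n} g(x_i)\,g(x_i)^\top,\\
W &= U S V^\top \;\Longrightarrow\; \mathrm{NFM} = V S^2 V^\top, \qquad
\mathrm{AGOP} = V S\,(U^\top K U)\,S V^\top,\\
\rho(\mathrm{NFM}, \mathrm{AGOP}) &= \frac{\langle S^2,\; S\,U^\top K U\,S \rangle_F}{\|S^2\|_F\,\|S\,U^\top K U\,S\|_F}.
\end{aligned}
```

Because V has orthonormal columns, it drops out of both the inner product and the norms. So in this toy setting ρ depends only on the singular values S and on how K looks in the left singular basis U. For example, if U^T K U = cI for some c > 0, then AGOP = c · NFM and ρ = 1.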
