Norm-Hierarchy Transitions in Representation Learning: When and Why Neural Networks Abandon Shortcuts
This work provides a unified explanation for phenomena like grokking and shortcut learning, which is significant for researchers trying to understand and control the learning dynamics of neural networks.
Neural networks initially rely on spurious shortcuts but eventually discover structured representations. This paper introduces the Norm-Hierarchy Transition (NHT) framework, explaining this delay as the slow traversal of a hierarchy of parameter norms during regularized optimization, moving from high-norm shortcut solutions to lower-norm structured representations.
Neural networks often rely on spurious shortcuts for many epochs before discovering structured representations. However, the mechanism that governs when this transition occurs, and whether its timing can be predicted, remains unclear. Prior work shows that gradient descent converges to low-norm solutions and that neural networks exhibit simplicity bias, but neither result explains the timescale of the transition from shortcut features to structured representations. We introduce the Norm-Hierarchy Transition (NHT) framework, which explains delayed representation learning as the slow traversal of a hierarchy of parameter norms during regularized optimization. When multiple interpolating solutions with different norms exist, weight decay gradually moves the model from high-norm shortcut solutions toward lower-norm structured representations. We derive a tight bound showing that the transition delay grows logarithmically with the ratio of the shortcut norm to the structured norm. Experiments on modular arithmetic, CIFAR-10 with spurious features, CelebA, and Waterbirds support the predictions of the framework. The results suggest that grokking, shortcut learning, and delayed feature discovery arise from a common mechanism: norm-hierarchy traversal during training.
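The logarithmic dependence can be illustrated with a deliberately simplified toy model. This is an assumption of the sketch, not the paper's derivation: once both the shortcut and the structured solution interpolate the training data, suppose the only force acting on the parameter norm is decoupled weight decay, which multiplies the parameters by (1 − ηλ) at each step (learning rate η, weight decay λ). Reaching the structured norm then takes t* = ln(‖θ_shortcut‖ / ‖θ_structured‖) / (−ln(1 − ηλ)) ≈ ln(‖θ_shortcut‖ / ‖θ_structured‖) / (ηλ) steps, so each tenfold increase in the norm ratio adds roughly the same number of steps. The function names `predicted_delay` and `simulated_delay` and the hyperparameter values are hypothetical.

```python
import math


def predicted_delay(shortcut_norm: float, structured_norm: float,
                    lr: float, weight_decay: float) -> float:
    """Steps for weight decay alone to shrink the norm from shortcut_norm
    to structured_norm (hypothetical toy model, not the paper's bound).

    Assumes each step rescales the parameters by (1 - lr * weight_decay),
    i.e. the data-loss gradient is ignored once both solutions interpolate.
    """
    if not shortcut_norm > structured_norm > 0:
        raise ValueError("expected shortcut_norm > structured_norm > 0")
    if not 0 < lr * weight_decay < 1:
        raise ValueError("expected 0 < lr * weight_decay < 1")
    # Solve shortcut_norm * (1 - lr*wd)**t = structured_norm for t.
    return math.log(shortcut_norm / structured_norm) / -math.log1p(-lr * weight_decay)


def simulated_delay(shortcut_norm: float, structured_norm: float,
                    lr: float, weight_decay: float,
                    max_steps: int = 10_000_000) -> int:
    """Count decay steps until the norm first reaches structured_norm."""
    norm = shortcut_norm
    for step in range(1, max_steps + 1):
        norm *= 1.0 - lr * weight_decay  # decoupled weight-decay shrinkage
        if norm <= structured_norm:
            return step
    raise RuntimeError("norm did not reach structured_norm within max_steps")


if __name__ == "__main__":
    lr, wd = 1e-2, 1e-2  # hypothetical hyperparameters, lr * wd = 1e-4
    for ratio in (10.0, 100.0, 1000.0):
        t_pred = predicted_delay(ratio, 1.0, lr, wd)
        t_sim = simulated_delay(ratio, 1.0, lr, wd)
        # Each 10x in the norm ratio adds about ln(10) / (lr * wd) steps.
        print(f"ratio={ratio:>6.0f}  predicted={t_pred:9.1f}  simulated={t_sim}")
```

Running the script should show the predicted and simulated step counts agreeing to within one step, with the delay growing by roughly ln(10)/(ηλ) ≈ 23,000 steps per tenfold increase in the ratio. Because the toy model ignores the data-loss gradient and the path between solutions, it illustrates only the scaling, not the constants in the paper's bound.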