Latent Algorithmic Structure Precedes Grokking: A Mechanistic Study of ReLU MLPs on Modular Arithmetic
This mechanistic insight into neural network generalization is incremental and mainly of interest to researchers studying learning dynamics in deep learning.
The study investigates grokking in ReLU MLPs trained on modular arithmetic. It finds that the models learn near-binary square-wave input weights, together with output weights that obey a specific phase-sum relation, even when they are trained on noisy data and fail to grok. An idealized model built from Fourier components extracted from such a non-grokking model achieves 95.5% accuracy. This suggests that grokking sharpens a pre-encoded algorithm rather than discovering it anew.
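The phase-sum relation can be illustrated with a toy version of the idealized model. The Python sketch below is not the authors' construction. It assumes an odd modulus p and unit amplitudes. It also assumes one neuron for every frequency from 1 to (p − 1)/2 and for every pair of phases on a hypothetical, evenly spaced grid, instead of frequencies and phases extracted from trained weights. Input weights are ±1 square waves, and output weights are cosines whose phase is set by $\phi_{\mathrm{out}} = \phi_a + \phi_b$. The function names are illustrative only.

```python
import math

# Toy idealized MLP for (a + b) mod p. ASSUMPTIONS (not the paper's setup):
# p is odd, amplitudes are 1, every frequency 1..(p-1)//2 is used, and input
# phases come from a hypothetical evenly spaced grid, not from trained weights.


def square_wave(x, k, phase, p):
    """Binary input weight: +1 where cos(2*pi*k*x/p + phase) >= 0, else -1."""
    return 1.0 if math.cos(2 * math.pi * k * x / p + phase) >= 0 else -1.0


def build_idealized_neurons(p):
    """Return (k, phi_a, phi_b, phi_out) tuples with phi_out = phi_a + phi_b."""
    neurons = []
    for k in range(1, (p - 1) // 2 + 1):
        for i in range(p):
            for j in range(p):
                phi_a = 2 * math.pi * i / p
                phi_b = 2 * math.pi * j / p
                neurons.append((k, phi_a, phi_b, phi_a + phi_b))  # phase-sum relation
    return neurons


def logits(a, b, neurons, p):
    """h = ReLU(w_a(a) + w_b(b)); class c accumulates h * cos(2*pi*k*c/p + phi_out)."""
    out = [0.0] * p
    for k, phi_a, phi_b, phi_out in neurons:
        h = max(0.0, square_wave(a, k, phi_a, p) + square_wave(b, k, phi_b, p))
        if h == 0.0:
            continue  # ReLU is silent unless both square waves are +1
        for c in range(p):
            out[c] += h * math.cos(2 * math.pi * k * c / p + phi_out)
    return out


def accuracy(p):
    """Fraction of all p*p input pairs whose argmax logit equals (a + b) mod p."""
    neurons = build_idealized_neurons(p)
    correct = 0
    for a in range(p):
        for b in range(p):
            z = logits(a, b, neurons, p)
            pred = max(range(p), key=z.__getitem__)
            correct += int(pred == (a + b) % p)
    return correct / (p * p)


print(accuracy(11))  # -> 1.0
```

Under these assumptions, the ReLU of two ±1 square waves equals 2 only where both waves are positive. Summing over the phase grid then leaves each frequency k contributing a term proportional to $\cos(2\pi k (c - a - b)/p)$, so the argmax falls on $(a + b) \bmod p$. Moving `phi_out` away from `phi_a + phi_b` is one way to probe how much the relation matters in this toy setting.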
Grokking is the phenomenon in which the validation accuracy of neural networks trained on modular addition of two integers rises long after the training data has been memorized. Previous work has characterized it as producing sinusoidal input weight distributions in transformers and multi-layer perceptrons (MLPs). We find empirically that ReLU MLPs in our experimental setting instead learn near-binary square-wave input weights, in which intermediate-valued weights appear only near sign-change boundaries. These come with output weight distributions whose dominant Fourier phases satisfy the phase-sum relation $\phi_{\mathrm{out}} = \phi_a + \phi_b$, where $\phi_a$ and $\phi_b$ are the phases of a neuron's input weights for the two operands. This relation holds even when the model is trained on noisy data and fails to grok. We extract the frequency and phase of each neuron's weights via the discrete Fourier transform (DFT) and construct an idealized MLP. In it, the input weights are replaced by perfect binary square waves and the output weights by cosines, both parametrized by the frequencies, phases, and amplitudes of the dominant Fourier components of the real model's weights. When the frequencies and phases are extracted from a model trained on noisy data that itself achieves only 0.23% accuracy, this idealized model achieves 95.5% accuracy. This suggests that grokking does not discover the correct algorithm. Rather, it sharpens an algorithm already substantially encoded during memorization, progressively binarizing the input weights into cleaner square waves and aligning the output weights until generalization becomes possible.
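The DFT extraction step can be sketched as follows. The sketch assumes that one neuron's weights over the p residues are available as a plain list of floats; this is a hypothetical layout, not the paper's code. It takes the largest bin among frequencies 1 to (p − 1)/2 and uses the convention that $A\cos(2\pi k x/p + \phi)$ has DFT phase $\phi$ at bin k. To stay self-contained, it uses a naive O(p²) DFT built on the standard library instead of an FFT routine.

```python
import cmath
import math


def dft(values):
    """Naive O(n^2) DFT: X[m] = sum_x values[x] * exp(-2j*pi*m*x/n)."""
    n = len(values)
    return [
        sum(v * cmath.exp(-2j * math.pi * m * x / n) for x, v in enumerate(values))
        for m in range(n)
    ]


def dominant_component(weights):
    """(frequency, phase, amplitude) of the largest bin among 1..(n-1)//2.

    Convention: A*cos(2*pi*k*x/n + phi) gives X[k] = (A*n/2)*exp(1j*phi),
    so the phase is read off directly and the amplitude is 2*|X[k]|/n.
    """
    n = len(weights)
    spectrum = dft(weights)
    k = max(range(1, (n - 1) // 2 + 1), key=lambda m: abs(spectrum[m]))
    return k, cmath.phase(spectrum[k]), 2 * abs(spectrum[k]) / n


# Usage on a synthetic +/-1 square wave (a hypothetical stand-in for one
# neuron's input weights over the residues 0..p-1).
p, k_true, phi_true = 59, 7, 1.0
w = [1.0 if math.cos(2 * math.pi * k_true * x / p + phi_true) >= 0 else -1.0 for x in range(p)]
k, phi, amp = dominant_component(w)
print(k)  # -> 7
```

For a ±1 square wave, the recovered amplitude is close to 4/π ≈ 1.27, the amplitude of the wave's fundamental. Because the wave is sampled at only p points, the phase is recovered to within π/p.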