Grokking modular arithmetic
This work addresses the interpretability of the representations that neural networks learn for modular arithmetic. The contribution is incremental, building on prior work on the grokking phenomenon.
The paper studies neural networks that learn modular arithmetic tasks and exhibit a sudden jump in generalization called 'grokking'. It shows that fully-connected two-layer networks grok under vanilla gradient descent with the MSE loss and no regularization. It also gives analytic expressions for the weights and feature maps that solve these tasks, together with evidence that gradient descent and AdamW find these same maps, which makes the learned representations interpretable.
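A minimal NumPy sketch of the training setup described above: a fully-connected two-layer network trained by full-batch gradient descent on the MSE loss, with no weight decay or other regularization. The task (modular addition), the quadratic activation, the one-hot input encoding, the width, learning rate, initialization, step count and 50/50 train/test split are all our assumptions, and `make_dataset` and `mse_loss_and_grads` are hypothetical helpers. This short run only shows the loss going down; it is not tuned to reproduce grokking, which depends on the training fraction and needs many more steps.

```python
import numpy as np

def make_dataset(p):
    """Hypothetical helper: all pairs (n, m) with label (n + m) mod p."""
    n, m = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    n, m = n.ravel(), m.ravel()
    X = np.zeros((p * p, 2 * p))
    X[np.arange(p * p), n] = 1.0          # one-hot encoding of n
    X[np.arange(p * p), p + m] = 1.0      # one-hot encoding of m
    Y = np.eye(p)[(n + m) % p]            # one-hot targets for the MSE loss
    return X, Y

def mse_loss_and_grads(X, Y, W1, W2):
    """Loss = sum((f(x) - y)^2) / (2B) for f(x) = W2 (W1 x)^2, plus gradients."""
    B = X.shape[0]
    Z = X @ W1.T                          # pre-activations, shape (B, width)
    H = Z ** 2                            # assumed quadratic activation
    out = H @ W2.T                        # network outputs, shape (B, p)
    err = out - Y
    loss = np.sum(err ** 2) / (2 * B)
    d_out = err / B
    grad_W2 = d_out.T @ H
    grad_W1 = ((d_out @ W2) * 2 * Z).T @ X
    return loss, grad_W1, grad_W2

p, width, lr, steps = 23, 256, 1e-3, 300
rng = np.random.default_rng(0)
X, Y = make_dataset(p)
perm = rng.permutation(p * p)
train, test = perm[: p * p // 2], perm[p * p // 2 :]   # assumed 50/50 split

W1 = rng.normal(0.0, np.sqrt(0.5), size=(width, 2 * p))
W2 = rng.normal(0.0, np.sqrt(1.0 / width), size=(p, width))

losses = []
for _ in range(steps):
    loss, gW1, gW2 = mse_loss_and_grads(X[train], Y[train], W1, W2)
    W1 -= lr * gW1                        # vanilla full-batch gradient descent:
    W2 -= lr * gW2                        # no momentum, no weight decay
    losses.append(loss)

test_out = ((X[test] @ W1.T) ** 2) @ W2.T
test_acc = np.mean(test_out.argmax(axis=1) == Y[test].argmax(axis=1))
print(f"train loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
print(f"test accuracy after {steps} steps: {test_acc:.3f}")
```

Swapping the two update lines for an AdamW optimizer would give the other training setting the summary mentions; that variant is not shown here.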
We present a simple neural network that can learn modular arithmetic tasks and exhibits a sudden jump in generalization known as 'grokking'. Concretely, we present (i) fully-connected two-layer networks that exhibit grokking on various modular arithmetic tasks under vanilla gradient descent with the MSE loss function in the absence of any regularization; (ii) evidence that grokking modular arithmetic corresponds to learning specific feature maps whose structure is determined by the task; (iii) analytic expressions for the weights (and thus for the feature maps) that solve a large class of modular arithmetic tasks; and (iv) evidence that these feature maps are also found by vanilla gradient descent as well as AdamW, thereby establishing complete interpretability of the representations learnt by the network.
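A minimal NumPy sketch of the kind of analytic solution item (iii) refers to, specialized to modular addition. The specific construction is our assumption, not a quotation of the paper's expressions: inputs are concatenated one-hot encodings of n and m, the activation is quadratic, each hidden unit k gets a random frequency f_k in {1, ..., p-1} and random phases, and the first- and second-layer weights are cosines of those. `one_hot_pairs`, `analytic_weights` and `forward` are hypothetical names. The check is that these hand-built weights classify every pair correctly without any training.

```python
import numpy as np

def one_hot_pairs(p):
    """All p*p inputs (n, m) as concatenated one-hot vectors of length 2p."""
    n, m = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    n, m = n.ravel(), m.ravel()
    X = np.zeros((p * p, 2 * p))
    X[np.arange(p * p), n] = 1.0
    X[np.arange(p * p), p + m] = 1.0
    return X, n, m

def analytic_weights(p, width, rng):
    """Hypothetical Fourier-feature weights for (n + m) mod p (illustrative)."""
    freqs = rng.integers(1, p, size=width)          # f_k in {1, ..., p-1}
    phi1 = rng.uniform(0, 2 * np.pi, size=width)    # phase for the first operand
    phi2 = rng.uniform(0, 2 * np.pi, size=width)    # phase for the second operand
    angle = 2 * np.pi * np.outer(freqs, np.arange(p)) / p   # (width, p)
    # W1[k, n] = cos(2 pi f_k n / p + phi1_k), W1[k, p + m] = cos(2 pi f_k m / p + phi2_k)
    W1 = np.concatenate([np.cos(angle + phi1[:, None]),
                         np.cos(angle + phi2[:, None])], axis=1)
    # W2[q, k] = cos(2 pi f_k q / p + phi1_k + phi2_k)
    W2 = np.cos(angle + (phi1 + phi2)[:, None]).T
    return W1, W2

def forward(X, W1, W2):
    """Two-layer network with quadratic activation: f(x) = W2 (W1 x)^2."""
    return ((X @ W1.T) ** 2) @ W2.T

p, width = 31, 2048
rng = np.random.default_rng(0)
X, n, m = one_hot_pairs(p)
W1, W2 = analytic_weights(p, width, rng)
pred = forward(X, W1, W2).argmax(axis=1)
accuracy = np.mean(pred == (n + m) % p)
print(f"accuracy of the hand-built weights: {accuracy:.3f}")
```

Why this works: with A = 2πf_k n/p + φ1_k and B = 2πf_k m/p + φ2_k, the quadratic activation gives (cos A + cos B)^2, which contains cos(A + B). Multiplying by the second-layer weight and summing over k yields a term proportional to the sum over k of cos(2πf_k(n + m − q)/p), which equals width/2 exactly when q ≡ n + m (mod p). The remaining terms carry random phases and stay of order sqrt(width), so a wide enough network picks out the correct residue.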