Neural Networks Learn Statistics of Increasing Complexity
This work advances the understanding of how neural networks learn hierarchical features. The contribution is incremental but relevant to researchers in machine learning and AI.
The paper provides new evidence for the distributional simplicity bias, showing that neural networks learn low-order statistics of data early in training before progressing to higher-order correlations, and extends this to discrete domains with empirical validation in LLMs.
The distributional simplicity bias (DSB) posits that neural networks learn low-order moments of the data distribution first, before moving on to higher-order correlations. In this work, we present compelling new evidence for the DSB by showing that networks automatically learn to perform well on maximum-entropy distributions whose low-order statistics match those of the training set early in training, then lose this ability later. We also extend the DSB to discrete domains by proving an equivalence between token $n$-gram frequencies and the moments of embedding vectors, and by finding empirical evidence for the bias in LLMs. Finally, we use optimal transport methods to surgically edit the low-order statistics of one class to match those of another, and show that early-training networks treat the edited samples as if they were drawn from the target class. Code is available at https://github.com/EleutherAI/features-across-time.
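To make the link between $n$-gram frequencies and embedding moments concrete, the following is a minimal sketch of the simplest case, assuming one-hot token embeddings. It is not taken from the paper or its repository; the helper names (`one_hot`, `ngram_moments`, `ngram_frequencies`) and the toy token sequence are hypothetical. Under this assumption, the mean embedding equals the unigram frequencies, and the cross moment of adjacent embeddings equals the bigram frequencies.

```python
import numpy as np

def one_hot(tokens, vocab_size):
    """Embed each token id as a one-hot vector (an illustrative assumption)."""
    return np.eye(vocab_size)[tokens]

def ngram_moments(tokens, vocab_size):
    """First moment and adjacent-pair cross moment of the one-hot embeddings."""
    x = one_hot(np.asarray(tokens), vocab_size)      # shape (T, V)
    first = x.mean(axis=0)                           # E[x_t]
    second = x[:-1].T @ x[1:] / (len(tokens) - 1)    # E[x_t x_{t+1}^T]
    return first, second

def ngram_frequencies(tokens, vocab_size):
    """Unigram and bigram relative frequencies, counted directly."""
    tokens = np.asarray(tokens)
    unigram = np.bincount(tokens, minlength=vocab_size) / len(tokens)
    bigram = np.zeros((vocab_size, vocab_size))
    # Accumulate one count per adjacent (previous, next) token pair.
    np.add.at(bigram, (tokens[:-1], tokens[1:]), 1.0)
    bigram /= len(tokens) - 1
    return unigram, bigram

# Hypothetical toy sequence over a vocabulary of size 3.
tokens = [0, 1, 2, 1, 0, 1, 2, 2]
first, second = ngram_moments(tokens, 3)
unigram, bigram = ngram_frequencies(tokens, 3)

# Both comparisons hold for one-hot embeddings.
print(np.allclose(first, unigram), np.allclose(second, bigram))
```

One-hot embeddings are chosen here only because they make the correspondence exact and easy to check by hand. Learned, dense embeddings would require the more general argument that the abstract refers to.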