Scaling Laws for Gradient Descent and Sign Descent for Linear Bigram Models under Zipf's Law
This work addresses optimization challenges in training transformer-based language models with large vocabularies. It provides theoretical scaling laws that help explain why methods like Adam outperform gradient descent. The contribution is incremental in that it builds on prior studies of how the data distribution affects optimization.
The paper tackles the optimization difficulty that heavy-tailed data distributions such as Zipf's law create when training language models. For linear bigram models on Zipf-distributed data, the number of iterations required by gradient descent scales almost linearly with the dimension, while for sign descent it scales only with the square root of the dimension, a significant improvement for large vocabularies.
Recent works have highlighted optimization difficulties faced by gradient descent in training the first and last layers of transformer-based language models, which are overcome by optimizers such as Adam. These works suggest that the difficulty is linked to the heavy-tailed distribution of words in text data, where the frequency $π_k$ of the $k$th most frequent word is proportional to $1/k$, following Zipf's law. To better understand the impact of the data distribution on training performance, we study a linear bigram model for next-token prediction when the tokens follow a power law $π_k \propto 1/k^α$ parameterized by the exponent $α > 0$. We derive optimization scaling laws for deterministic gradient descent and for sign descent, used as a proxy for Adam, as a function of the exponent $α$. Existing theoretical investigations of scaling laws assume that the eigenvalues of the data decay as a power law with exponent $α > 1$. This assumption effectively makes the problem ``finite dimensional'', as most of the loss comes from a few of the largest eigencomponents. In comparison, we show that the problem is more difficult when the data have heavier tails. The case $α = 1$, as found in text data, is ``worst-case'' for gradient descent, in that the number of iterations required to reach a small relative error scales almost linearly with the dimension. While the performance of sign descent also depends on the dimension, for Zipf-distributed data the number of iterations scales only with the square root of the dimension, leading to a large improvement for large vocabularies.
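The setting can be illustrated with a minimal NumPy sketch. It is a toy under stated assumptions, not the paper's experimental setup: it assumes a squared loss with exact (population) gradients, one-hot inputs so that row $k$ of the gradient is scaled by $π_k$, a conditional next-token distribution equal to $π$ for every previous token, zero initialization, and a gradient descent step size of $1/π_1$. All function names are hypothetical. The sketch counts gradient descent iterations to a fixed relative error as the dimension grows and shows the form of a sign descent update; it does not attempt to reproduce the square-root rate for sign descent.

```python
import numpy as np


def zipf_frequencies(d, alpha=1.0):
    """Return pi with pi_k proportional to 1 / k**alpha for k = 1..d."""
    ranks = np.arange(1, d + 1, dtype=float)
    weights = ranks ** (-alpha)
    return weights / weights.sum()


def excess_loss(W, P, pi):
    """Excess squared loss: 0.5 * sum_k pi_k * ||W[k] - P[k]||^2 (assumed loss)."""
    return 0.5 * np.sum(pi[:, None] * (W - P) ** 2)


def gradient(W, P, pi):
    """Gradient of excess_loss with respect to W; row k is scaled by pi_k."""
    return pi[:, None] * (W - P)


def gd_step(W, P, pi, lr):
    # Rows of rare tokens (small pi_k) receive proportionally small updates.
    return W - lr * gradient(W, P, pi)


def sign_step(W, P, pi, lr):
    # Only the sign of each gradient entry is used, so the size of the
    # update does not depend on pi_k.
    return W - lr * np.sign(gradient(W, P, pi))


def gd_iterations_to_error(d, alpha=1.0, eps=0.1, max_iters=100_000):
    """Count gradient descent steps until the loss falls to eps * initial loss."""
    pi = zipf_frequencies(d, alpha)
    # Assumption: the next token is independent of the previous one,
    # so every row of the target matrix P equals pi.
    P = np.tile(pi, (d, 1))
    W = np.zeros((d, d))
    lr = 1.0 / pi[0]  # inverse of the largest curvature in this toy problem
    initial = excess_loss(W, P, pi)
    for t in range(1, max_iters + 1):
        W = gd_step(W, P, pi, lr)
        if excess_loss(W, P, pi) <= eps * initial:
            return t
    return None  # target not reached within max_iters


if __name__ == "__main__":
    for d in (64, 256, 1024):
        print(d, gd_iterations_to_error(d))
```

Running the script prints, for each dimension, the number of gradient descent iterations needed to reach 10% of the initial loss in this toy problem. Changing `alpha` in `gd_iterations_to_error` gives a quick way to compare heavier and lighter tails.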