ML · LG · Jun 4, 2021

Learning Curves for SGD on Structured Features

arXiv:2106.02713v5 · 5 citations
Originality: Incremental advance
AI Analysis

This work offers machine learning practitioners theoretical insight into SGD optimization. Its originality is incremental: it builds on existing models to analyze how data structure affects learning.

The authors develop an exactly solvable model of stochastic gradient descent (SGD) on mean square loss to analyze how data structure affects test loss dynamics. They show that the model accurately predicts performance on real datasets such as MNIST and CIFAR-10, and that optimal batch sizes are typically small and depend on feature correlations.
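As an illustration of what an exactly solvable model of SGD on mean square loss can look like, here is a minimal sketch under assumptions chosen for this example: Gaussian features x ~ N(0, Σ), a noise-free linear target y = w*·x, online SGD with a fresh batch of size B at every step, learning rate η, and initialization w = 0. Working in the eigenbasis of Σ (eigenvalues λ_k), the Gaussian fourth-moment identity gives a closed recursion for c_k = E[(w - w*)_k^2]:

c_k <- (1 - 2ηλ_k + η^2 (1 + 1/B) λ_k^2) c_k + (η^2 / B) λ_k Σ_l λ_l c_l, with test loss L = Σ_k λ_k c_k = E[((w - w*)·x)^2].

This derivation and the function names, spectrum, and hyperparameters below are hypothetical illustrations, not necessarily the paper's exact formulation. The code compares the recursion with a Monte Carlo average over SGD runs.

```python
import numpy as np


def theory_test_loss(lams, v, eta, batch, steps):
    """Expected test loss per step from the closed recursion (Gaussian features, w_0 = 0).

    lams: eigenvalues of the feature covariance Sigma.
    v:    target weights w* expressed in Sigma's eigenbasis.
    """
    c = v ** 2  # with w_0 = 0 the initial error is -w*, so c_k = v_k^2
    losses = [float(lams @ c)]
    for _ in range(steps):
        loss = lams @ c
        # Per-mode contraction plus the mode-coupling term caused by minibatch noise.
        c = (1 - 2 * eta * lams + eta**2 * (1 + 1 / batch) * lams**2) * c \
            + (eta**2 / batch) * lams * loss
        losses.append(float(lams @ c))
    return np.array(losses)


def simulate_test_loss(lams, v, eta, batch, steps, trials, seed=0):
    """Monte Carlo estimate of the same quantity, averaged over independent SGD runs."""
    rng = np.random.default_rng(seed)
    d = lams.size
    delta = np.tile(-v, (trials, 1))  # w_0 - w* for each trial, shape (trials, d)
    losses = [float(np.mean(delta**2 @ lams))]
    for _ in range(steps):
        # Fresh batch of features with covariance diag(lams), shape (trials, batch, d).
        x = rng.standard_normal((trials, batch, d)) * np.sqrt(lams)
        resid = np.einsum("tbd,td->tb", x, delta)            # (w - w*) . x per sample
        grad = np.einsum("tbd,tb->td", x, resid) / batch     # gradient of the mean half-squared error
        delta = delta - eta * grad
        losses.append(float(np.mean(delta**2 @ lams)))       # exact test loss under diag(lams)
    return np.array(losses)


if __name__ == "__main__":
    k = np.arange(1, 101)
    lams = k ** -1.5          # hypothetical power-law feature spectrum
    v = np.ones_like(lams)    # hypothetical target coefficients
    theory = theory_test_loss(lams, v, eta=0.5, batch=4, steps=100)
    sim = simulate_test_loss(lams, v, eta=0.5, batch=4, steps=100, trials=2000)
    for t in (0, 10, 50, 100):
        print(f"t={t:3d}  theory={theory[t]:.4f}  simulation={sim[t]:.4f}")
```

Because the recursion is an exact expectation under these Gaussian assumptions, the simulated curve should agree with it up to sampling error; for non-Gaussian features the two need not coincide.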

The generalization performance of a machine learning algorithm such as a neural network depends in a non-trivial way on the structure of the data distribution. To analyze the influence of data structure on test loss dynamics, we study an exactly solvable model of stochastic gradient descent (SGD) on mean square loss which predicts test loss when training on features with arbitrary covariance structure. We solve the theory exactly for both Gaussian features and arbitrary features, and we show that the simpler Gaussian model accurately predicts test loss of nonlinear random-feature models and deep neural networks trained with SGD on real datasets such as MNIST and CIFAR-10. We show that the optimal batch size at a fixed compute budget is typically small and depends on the feature correlation structure, demonstrating the computational benefits of SGD with small batch sizes. Lastly, we extend our theory to the more usual setting of stochastic gradient descent on a fixed subsampled training set, showing that both training and test error can be accurately predicted in our framework on real data.
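The fixed-compute batch-size question from the abstract can be explored numerically with a Gaussian-feature recursion. This sketch assumes online SGD on a noise-free linear target with x ~ N(0, Σ), w_0 = 0, a compute budget measured as the total number of samples processed (B times the number of steps), and, for each batch size B, the best learning rate from a small grid. The recursion, the budget, the learning-rate grid, the spectrum, and all function names are hypothetical choices for illustration, not the paper's protocol or values.

```python
import numpy as np


def final_test_loss(lams, v, eta, batch, steps):
    """Expected test loss after `steps` online-SGD steps (noise-free Gaussian model, w_0 = 0)."""
    c = v ** 2  # diagonal of the error covariance in Sigma's eigenbasis
    # Large learning rates can diverge; let the loss overflow to inf quietly.
    with np.errstate(over="ignore", invalid="ignore"):
        for _ in range(steps):
            loss = lams @ c
            c = (1 - 2 * eta * lams + eta**2 * (1 + 1 / batch) * lams**2) * c \
                + (eta**2 / batch) * lams * loss
        return float(lams @ c)


def best_batch_at_fixed_compute(lams, v, budget, batches, etas):
    """Return (best batch size, {B: lowest final loss over the eta grid})."""
    results = {}
    for b in batches:
        steps = budget // b  # every B processes the same number of samples
        losses = [final_test_loss(lams, v, eta, b, steps) for eta in etas]
        # Treat diverged runs (inf or nan) as infinitely bad.
        losses = [x if np.isfinite(x) else np.inf for x in losses]
        results[b] = min(losses)
    best = min(results, key=results.get)
    return best, results


if __name__ == "__main__":
    k = np.arange(1, 201)
    lams = k ** -1.5   # hypothetical power-law feature spectrum
    v = k ** -0.25     # hypothetical target coefficients
    best, results = best_batch_at_fixed_compute(
        lams, v, budget=1024, batches=[1, 2, 4, 8, 16, 32, 64],
        etas=np.linspace(0.05, 1.5, 30))
    for b, loss in results.items():
        print(f"B={b:3d}  best test loss={loss:.4f}")
    print("best batch size:", best)
```

Varying the exponents of the spectrum or the target changes the loss at each batch size, which offers one way to probe how the optimum depends on feature correlation structure.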

Code Implementations: 1 repo