Generalization and Stability of Interpolating Neural Networks with Minimal Width
This work provides theoretical guarantees for training neural networks of minimal width, addressing limitations of existing stability-based generalization analyses in non-convex settings; it is aimed at machine learning researchers.
The paper studies the generalization and optimization of shallow neural networks in the interpolating regime. It shows that gradient descent achieves training error $O(g(1/T)^2/T)$ and generalization error $O(g(1/T)^2/n)$ with $m=\Omega(g(1/T)^4)$ hidden neurons. For data separable by the neural tangent kernel, $m=\Omega(\log^4 n)$ neurons and $T\approx n$ iterations suffice for a test-loss bound of $\tilde O(1/n)$.
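The NTK-case rates can be read off the general ones. Here $g(\varepsilon)$ is the distance from initialization of weights that reach training error $\varepsilon$. The derivation below rests on two interpretive assumptions, neither of which is stated in the paper. First, $g(1/T)=O(\log T)$ for NTK-separable data, which the matching width conditions $m=\Omega(g(1/T)^4)$ and $m=\Omega(\log^4 T)$ suggest. Second, the test loss is at most the training loss plus the generalization gap, up to constants.

$$
\begin{aligned}
\text{training loss} &= O\!\left(\frac{g(1/T)^2}{T}\right) = O\!\left(\frac{\log^2 T}{T}\right) = \tilde O\!\left(\frac{1}{T}\right), \qquad m = \Omega\big(g(1/T)^4\big) = \Omega(\log^4 T),\\
\text{test loss} &\lesssim O\!\left(\frac{g(1/T)^2}{T}\right) + O\!\left(\frac{g(1/T)^2}{n}\right) \overset{T\approx n}{=} O\!\left(\frac{\log^2 n}{n}\right) = \tilde O\!\left(\frac{1}{n}\right), \qquad m = \Omega(\log^4 n).
\end{aligned}
$$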
We investigate the generalization and optimization properties of shallow neural-network classifiers trained by gradient descent in the interpolating regime. Specifically, in a realizable scenario where model weights can achieve arbitrarily small training error $ε$ and their distance from initialization is $g(ε)$, we demonstrate that gradient descent with $n$ training data achieves training error $O(g(1/T)^2 /T)$ and generalization error $O(g(1/T)^2 /n)$ at iteration $T$, provided there are at least $m=Ω(g(1/T)^4)$ hidden neurons. We then show that our realizable setting encompasses a special case where data are separable by the model's neural tangent kernel. In this case, for logistic-loss minimization, we prove that the training loss decays at a rate of $\tilde O(1/ T)$ given a polylogarithmic number of neurons $m=Ω(\log^4 (T))$. Moreover, with $m=Ω(\log^{4} (n))$ neurons and $T\approx n$ iterations, we bound the test loss by $\tilde{O}(1/n)$. In contrast, existing generalization results based on the algorithmic-stability framework require polynomial width and yield suboptimal generalization rates. Central to our analysis is a new self-bounded weak-convexity property, which leads to a generalized local quasi-convexity property for sufficiently parameterized neural-network classifiers. Ultimately, despite the objective's non-convexity, this yields convergence and generalization-gap bounds that resemble those found in the convex setting of linear logistic regression.
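The sketch below shows the kind of training loop these bounds describe: full-batch gradient descent on the logistic loss for a shallow classifier. It makes several assumptions that do not come from the paper. The network is a two-layer tanh model $f(W;x)=\frac{1}{\sqrt m}\sum_j a_j \tanh(\langle w_j, x\rangle)$ with fixed output weights $a_j=\pm 1$ and a trained first layer. The step size and width are illustrative, and the toy dataset is linearly separable. The helpers `make_data`, `forward`, `empirical_loss`, and `train` are hypothetical. The sketch does not check the paper's rates or width thresholds.

```python
import math
import random

def logistic_loss(z):
    """Numerically stable log(1 + exp(-z)) for margin z = y * f(x)."""
    if z > 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))

def neg_logistic_grad(z):
    """Returns -d/dz log(1 + exp(-z)) = 1 / (1 + exp(z)), computed stably."""
    if z > 0:
        e = math.exp(-z)
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(z))

def make_data(n, seed=0):
    """Hypothetical toy data: x in [-1, 1]^2, label sign(x[0]), kept off the boundary."""
    rng = random.Random(seed)
    data = []
    while len(data) < n:
        x = (rng.uniform(-1, 1), rng.uniform(-1, 1))
        if abs(x[0]) >= 0.1:
            data.append((x, 1.0 if x[0] > 0 else -1.0))
    return data

def forward(W, a, x):
    """Two-layer net f(W; x) = (1/sqrt(m)) * sum_j a_j * tanh(<w_j, x>)."""
    m = len(W)
    return sum(a_j * math.tanh(w[0] * x[0] + w[1] * x[1]) for w, a_j in zip(W, a)) / math.sqrt(m)

def empirical_loss(W, a, data):
    """Average logistic loss over the training set."""
    return sum(logistic_loss(y * forward(W, a, x)) for x, y in data) / len(data)

def train(data, m=16, eta=1.0, steps=300, seed=1):
    """Full-batch GD on the first-layer weights; returns the loss at every iterate."""
    rng = random.Random(seed)
    W = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(m)]
    a = [1.0 if j % 2 == 0 else -1.0 for j in range(m)]  # fixed output layer, not trained
    n, scale = len(data), 1.0 / math.sqrt(m)
    losses = [empirical_loss(W, a, data)]
    for _ in range(steps):
        grad = [[0.0, 0.0] for _ in range(m)]
        for x, y in data:
            z = y * forward(W, a, x)
            c = -y * neg_logistic_grad(z) / n  # dL/df for this sample
            for j, w in enumerate(W):
                t = math.tanh(w[0] * x[0] + w[1] * x[1])
                g = c * scale * a[j] * (1.0 - t * t)  # chain rule through tanh
                grad[j][0] += g * x[0]
                grad[j][1] += g * x[1]
        for j in range(m):
            W[j][0] -= eta * grad[j][0]
            W[j][1] -= eta * grad[j][1]
        losses.append(empirical_loss(W, a, data))
    return losses

if __name__ == "__main__":
    losses = train(make_data(40))
    print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

With inputs in $[-1,1]^2$, the per-sample loss is smooth with a constant below 1, so $\eta=1$ lies below $2/\beta$ and every step decreases the empirical loss. This is a standard smoothness argument, not the paper's self-bounded weak-convexity analysis.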