The merged-staircase property: a necessary and nearly sufficient condition for SGD learning of sparse functions on two-layer neural networks
This work addresses a key gap in understanding how neural networks efficiently learn high-dimensional data with latent low-dimensional structure, and it offers theoretical insight for machine learning researchers. It is somewhat incremental, since it builds on prior characterizations of the two extreme regimes.
The paper characterizes which sparse functions on binary inputs can be learned by two-layer neural networks trained with SGD in the mean-field regime. It shows that a "merged-staircase property" is necessary and nearly sufficient for efficient learning with $O(d)$ sample complexity in high dimension $d$. It also proves that non-linear training is essential: linear methods on any feature map, such as the NTK, fail to learn this class of functions efficiently.
It is currently known how to characterize the functions that neural networks can learn with SGD for two extremal parameterizations: neural networks in the linear regime, and neural networks with no structural constraints. However, for the main parameterization of interest (non-linear but regular networks), no tight characterization has yet been achieved, despite significant developments. We take a step in this direction by considering depth-2 neural networks trained by SGD in the mean-field regime. We consider functions on binary inputs that depend on a latent low-dimensional subspace (i.e., on a small number of coordinates). This regime is of interest because it is poorly understood how neural networks routinely tackle high-dimensional datasets and adapt to latent low-dimensional structure without suffering from the curse of dimensionality. Accordingly, we study SGD-learnability with $O(d)$ sample complexity in a large ambient dimension $d$. Our main results characterize a hierarchical property, the "merged-staircase property", that is both necessary and nearly sufficient for learning in this setting. We further show that non-linear training is necessary: for this class of functions, linear methods on any feature map (e.g., the NTK) cannot learn efficiently. The key tools are a new "dimension-free" dynamics approximation result that applies to functions defined on a latent space of low dimension, a proof of global convergence based on polynomial identity testing, and improved lower bounds against linear methods for functions that are not almost orthogonal.
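The abstract names the merged-staircase property without stating it. The sketch below assumes the definition from the full paper rather than from this page: a target is represented by the sets of latent coordinates appearing in the monomials of its expansion, and the property holds when those sets can be ordered so that each set adds at most one coordinate not already covered by the sets before it. For example, $x_1 + x_1x_2 + x_1x_2x_3$ qualifies, while $x_1x_2x_3$ on its own does not. The function name `satisfies_msp` is hypothetical. The check is greedy, which is sound here because placing a set only enlarges the covered coordinates and so can never block a set that was placeable before.

```python
from typing import Iterable, List, Set


def satisfies_msp(supports: Iterable[Iterable[int]]) -> bool:
    """Hypothetical checker for the merged-staircase property (assumed definition).

    `supports` lists, for each monomial in the target's expansion, the latent
    coordinates it depends on. Returns True if the sets can be ordered so that
    each one introduces at most one coordinate not covered by earlier sets.
    """
    remaining: List[frozenset] = [frozenset(s) for s in supports]
    covered: Set[int] = set()
    while remaining:
        for i, s in enumerate(remaining):
            # A set with at most one uncovered coordinate can be placed next.
            if len(s - covered) <= 1:
                covered |= s
                remaining.pop(i)
                break
        else:
            # Every remaining set would introduce two or more new coordinates.
            return False
    return True


# x1 + x1*x2 + x1*x2*x3: each monomial adds one new coordinate.
print(satisfies_msp([{1}, {1, 2}, {1, 2, 3}]))
# x1*x2*x3 alone introduces three coordinates at once.
print(satisfies_msp([{1, 2, 3}]))
```

The input order does not matter: the loop searches for any placeable set at each step, so `[{1, 2, 3}, {1, 2}, {1}]` is accepted just like the sorted staircase.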