Learning quadratic neural networks in high dimensions: SGD dynamics and scaling laws
This paper provides theoretical insight into the gradient-based training dynamics of neural networks, a foundational problem in machine learning theory.
The paper tackles the optimization and sample complexity of training a two-layer neural network with quadratic activation in high dimensions, deriving scaling laws for prediction risk that depend on optimization time, sample size, and model width.
We study the optimization and sample complexity of gradient-based training of a two-layer neural network with quadratic activation function in the high-dimensional regime, where the data is generated as $y \propto \sum_{j=1}^{r} \lambda_j \sigma\left(\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle\right), \boldsymbol{x} \sim N(0,\boldsymbol{I}_d)$, $\sigma$ is the 2nd Hermite polynomial, and $\lbrace \boldsymbol{\theta}_j \rbrace_{j=1}^{r} \subset \mathbb{R}^d$ are orthonormal signal directions. We consider the extensive-width regime $r \asymp d^{\beta}$ for $\beta \in [0, 1)$, and assume a power-law decay on the (non-negative) second-layer coefficients $\lambda_j \asymp j^{-\alpha}$ for $\alpha \geq 0$. We present a sharp analysis of the SGD dynamics in the feature learning regime, for both the population limit and the finite-sample (online) discretization, and derive scaling laws for the prediction risk that highlight the power-law dependencies on the optimization time, sample size, and model width. Our analysis combines a precise characterization of the associated matrix Riccati differential equation with novel matrix monotonicity arguments to establish convergence guarantees for the infinite-dimensional effective dynamics.
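The data model described above can be illustrated with a minimal NumPy sketch. Several choices here are assumptions made for illustration and are not specified in the abstract: the unnormalized probabilist's form $He_2(z) = z^2 - 1$ for the 2nd Hermite polynomial, the constant 1 in $r \asymp d^{\beta}$ and $\lambda_j \asymp j^{-\alpha}$, the normalization of $y$ by $\|\lambda\|_2$ standing in for the proportionality sign, and the hypothetical helper names `hermite2`, `make_teacher`, and `sample`.

```python
import numpy as np


def hermite2(z):
    """Probabilist's 2nd Hermite polynomial He_2(z) = z^2 - 1 (assumed form)."""
    return z**2 - 1.0


def make_teacher(d, beta, alpha, rng):
    """Draw r ~ d^beta orthonormal directions and power-law coefficients."""
    r = max(1, int(round(d**beta)))  # assumed constant 1 in r ≍ d^beta
    # Reduced QR of a Gaussian matrix yields d x r orthonormal columns.
    q, _ = np.linalg.qr(rng.standard_normal((d, r)))
    theta = q.T  # shape (r, d); rows are the signal directions
    lam = np.arange(1, r + 1, dtype=float) ** (-alpha)  # lambda_j = j^{-alpha}
    return theta, lam


def sample(theta, lam, n, rng):
    """Sample x ~ N(0, I_d) and y = sum_j lambda_j He_2(<theta_j, x>) / ||lambda||."""
    d = theta.shape[1]
    x = rng.standard_normal((n, d))
    y = hermite2(x @ theta.T) @ lam
    y = y / np.linalg.norm(lam)  # assumed normalization for the "∝"
    return x, y


rng = np.random.default_rng(0)
theta, lam = make_teacher(d=64, beta=0.5, alpha=1.0, rng=rng)
x, y = sample(theta, lam, n=1000, rng=rng)
```

Here `beta` controls how many signal directions are present relative to the ambient dimension and `alpha` controls how quickly their second-layer weights decay, the two quantities on which the abstract's scaling laws are stated to depend. The sketch covers data generation only and does not reproduce the SGD analysis.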