On Learning Gaussian Multi-index Models with Gradient Flow
This work addresses feature learning in neural networks by analyzing convergence and optimization challenges in multi-index models, and represents an incremental theoretical advance.
The paper tackles the problem of learning multi-index models with gradient flow on high-dimensional Gaussian data, establishing global convergence for a two-timescale algorithm and characterizing saddle-to-saddle dynamics, while showing that a related planted problem has a rough optimization landscape where gradient flow can get trapped.
We study gradient flow on the multi-index regression problem for high-dimensional Gaussian data. Multi-index functions are compositions of an unknown low-rank linear projection and an arbitrary, unknown low-dimensional link function. As such, they constitute a natural template for feature learning in neural networks. We consider a two-timescale algorithm, whereby the low-dimensional link function is learnt with a non-parametric model infinitely faster than the subspace parametrizing the low-rank projection. By appropriately exploiting the matrix semigroup structure arising over the subspace correlation matrices, we establish global convergence of the resulting Grassmannian population gradient flow dynamics and provide a quantitative description of its associated "saddle-to-saddle" dynamics. Notably, the timescales associated with each saddle can be explicitly characterized in terms of an appropriate Hermite decomposition of the target link function. In contrast with these positive results, we also show that the related planted problem, where the link function is known and fixed, in fact has a rough optimization landscape, in which gradient flow dynamics might get trapped with high probability.
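The two-timescale scheme can be illustrated with a small finite-sample sketch. This is not the authors' method, and it rests on several substitutions. The paper analyzes population gradient flow with a non-parametric link. Here, instead, the link is fitted by least squares on a total-degree Hermite polynomial basis of degree p, the expectation is replaced by an average over n Gaussian samples, and the flow is discretized with a QR retraction. The target is y = g(W_star^T x) with x ~ N(0, I_d) and W_star a d x k matrix with orthonormal columns. Every name below (hermite, multi_indices, fit_link, loss_and_grad, two_timescale_step) is hypothetical, as is the example link g(z) = He_1(z_1) + He_2(z_2). Because the link coefficients are refit to optimality before each subspace step, the envelope theorem makes the partial W-gradient with coefficients held fixed equal to the gradient of the profiled loss.

```python
import itertools

import numpy as np


def hermite(z, p):
    """Probabilists' Hermite polynomials He_0..He_p at z, shape z.shape + (p + 1,)."""
    H = np.empty(z.shape + (p + 1,))
    H[..., 0] = 1.0
    if p >= 1:
        H[..., 1] = z
    for m in range(1, p):
        # Three-term recurrence: He_{m+1}(z) = z He_m(z) - m He_{m-1}(z)
        H[..., m + 1] = z * H[..., m] - m * H[..., m - 1]
    return H


def multi_indices(k, p):
    """Multi-indices alpha in N^k with |alpha| <= p (a rotation-invariant basis)."""
    return [a for a in itertools.product(range(p + 1), repeat=k) if sum(a) <= p]


def features(Z, alphas, p):
    """Tensor Hermite features Phi (n, M) and their z-derivatives dPhi (n, M, k)."""
    H = hermite(Z, p)  # shape (n, k, p + 1)
    n, k = Z.shape
    Phi = np.empty((n, len(alphas)))
    dPhi = np.empty((n, len(alphas), k))
    for m, a in enumerate(alphas):
        cols = [H[:, j, a[j]] for j in range(k)]
        Phi[:, m] = np.prod(cols, axis=0)
        for j in range(k):
            dcols = list(cols)
            # He_a'(z) = a He_{a-1}(z); the derivative of He_0 is zero
            dcols[j] = a[j] * H[:, j, a[j] - 1] if a[j] > 0 else np.zeros(n)
            dPhi[:, m, j] = np.prod(dcols, axis=0)
    return Phi, dPhi


def fit_link(X, y, W, alphas, p, ridge=1e-8):
    """Fast timescale: least-squares fit of the link on features of z = W^T x."""
    Phi, _ = features(X @ W, alphas, p)
    A = Phi.T @ Phi + ridge * np.eye(Phi.shape[1])
    return np.linalg.solve(A, Phi.T @ y)


def loss_and_grad(X, y, W, c, alphas, p):
    """Loss 0.5 * mean((f - y)^2) and its Euclidean gradient in W, with c held fixed."""
    Phi, dPhi = features(X @ W, alphas, p)
    r = Phi @ c - y                        # residuals, shape (n,)
    J = np.einsum("nmj,m->nj", dPhi, c)    # df/dz, shape (n, k)
    G = X.T @ (r[:, None] * J) / len(y)    # dL/dW, shape (d, k)
    return 0.5 * np.mean(r ** 2), G


def two_timescale_step(X, y, W, alphas, p, lr):
    """Slow timescale: one Riemannian gradient step on the Grassmannian."""
    c = fit_link(X, y, W, alphas, p)       # link refit to optimality first
    loss, G = loss_and_grad(X, y, W, c, alphas, p)
    G = G - W @ (W.T @ G)                  # project onto the tangent (horizontal) space
    Q, _ = np.linalg.qr(W - lr * G)        # QR retraction back to orthonormal columns
    return Q, loss


rng = np.random.default_rng(0)
d, k, n, p = 20, 2, 20_000, 2
W_star, _ = np.linalg.qr(rng.standard_normal((d, k)))  # hypothetical target subspace
X = rng.standard_normal((n, d))
Z = X @ W_star
y = Z[:, 0] + (Z[:, 1] ** 2 - 1.0)  # hypothetical link g(z) = He_1(z_1) + He_2(z_2)

alphas = multi_indices(k, p)
W, _ = np.linalg.qr(rng.standard_normal((d, k)))  # random initial subspace
losses = []
for _ in range(300):
    W, loss = two_timescale_step(X, y, W, alphas, p, lr=0.5)
    losses.append(loss)

# Subspace overlap ||W_star^T W||_F^2 / k: equals 1 iff the two subspaces coincide
overlap = np.linalg.norm(W_star.T @ W) ** 2 / k
print(f"final loss {losses[-1]:.4f}, subspace overlap {overlap:.3f}")
```

The Grassmannian step projects out the component of the gradient that lies inside span(W). With the total-degree basis, the profiled loss depends only on that span, so the projected component is the part that rotates the subspace. One way to look for plateaus of the kind associated with saddle-to-saddle dynamics is to plot `losses` against the iteration count. The sketch makes no claim to reproduce the paper's timescales or its Hermite-based characterization of them.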