Gradient-Based Feature Learning under Structured Data
This work addresses efficient feature learning from structured data, a question of practical relevance to machine learning practitioners, and shows how gradient-based training can be improved in anisotropic settings.
The paper studies gradient-based learning of single index models when the input data are anisotropic with a spiked covariance. It shows that standard spherical gradient dynamics can fail to recover the target direction even when the spike is aligned with it, whereas a weight normalization similar to batch normalization resolves this. With a suitably large spike, the resulting sample complexity is independent of the information exponent and beats the lower bounds for rotationally invariant kernel methods.
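To make the setting concrete, here is a minimal NumPy sketch of one possible instance. It is not the authors' algorithm, and every modeling choice is an assumption: inputs x ~ N(0, I_d + kappa u u^T) with the spike aligned with the target direction u = e_1, the link He_3(z) = z^3 - 3z, a correlation loss L(w) = -mean(y * He_3(.)), and full-batch gradients. `spherical_grad` projects the gradient onto the tangent space of the unit sphere, while `normalized_grad` divides the pre-activation <w, x> by its batch RMS, a simplified batch-norm-like stand-in that is not necessarily the normalization analyzed in the paper. All function names are hypothetical.

```python
import numpy as np


def he3(z):
    # Third probabilists' Hermite polynomial He_3(z) = z^3 - 3z (information exponent 3).
    return z**3 - 3.0 * z


def he3_prime(z):
    return 3.0 * z**2 - 3.0


def spiked_single_index_data(n, d, kappa, rng):
    """Sample x ~ N(0, I_d + kappa * u u^T) and y = He_3(<u, x> / sqrt(1 + kappa)).

    Assumption: the spike is aligned with the target direction u = e_1, and the
    target's argument is rescaled by ||Sigma^{1/2} u|| = sqrt(1 + kappa) to unit variance.
    """
    u = np.zeros(d)
    u[0] = 1.0
    g = rng.standard_normal((n, d))
    # (I + c u u^T) g with c = sqrt(1 + kappa) - 1 has covariance I + kappa u u^T.
    x = g + (np.sqrt(1.0 + kappa) - 1.0) * np.outer(g @ u, u)
    y = he3((x @ u) / np.sqrt(1.0 + kappa))
    return x, y, u


def spherical_grad(w, x, y):
    """Spherical (Riemannian) gradient of L(w) = -mean(y * He_3(<w, x>)) at a unit vector w."""
    g = -(x.T @ (y * he3_prime(x @ w))) / len(y)
    return g - (g @ w) * w  # remove the radial component


def normalized_grad(w, x, y):
    """Gradient of L(w) = -mean(y * He_3(<w, x> / s(w))), with s(w) the batch RMS of <w, x>.

    L is invariant to rescaling w, so this gradient is already orthogonal to w.
    """
    n = len(y)
    pre = x @ w
    s = np.sqrt(np.mean(pre**2))
    z = pre / s
    a = y * he3_prime(z)
    # Chain rule through both the numerator <w, x_i> and the normalizer s(w).
    return -(x.T @ a - (a @ z) * (x.T @ z) / n) / (n * s)


def step(w, grad, lr):
    # Gradient step followed by retraction onto the unit sphere.
    w = w - lr * grad
    return w / np.linalg.norm(w)
```

Iterating `w = step(w, spherical_grad(w, x, y), lr)` and `w = step(w, normalized_grad(w, x, y), lr)` side by side while tracking |<w, u>| is one way to probe the contrast between the two dynamics; the sketch makes no claim about what a particular run will show. Because the normalized loss is scale invariant, its gradient needs no tangent projection, and the final renormalization in `step` only fixes the scale of w without changing the predictor.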
Recent works have demonstrated that the sample complexity of gradient-based learning of single index models, i.e., functions that depend on a 1-dimensional projection of the input data, is governed by their information exponent. However, these results are only concerned with isotropic data, while in practice the input often contains additional structure which can implicitly guide the algorithm. In this work, we investigate the effect of a spiked covariance structure and reveal several interesting phenomena. First, we show that in the anisotropic setting, the commonly used spherical gradient dynamics may fail to recover the true direction, even when the spike is perfectly aligned with the target direction. Next, we show that an appropriate weight normalization, reminiscent of batch normalization, can alleviate this issue. Further, by exploiting the alignment between the (spiked) input covariance and the target, we obtain improved sample complexity compared to the isotropic case. In particular, under the spiked model with a suitably large spike, the sample complexity of gradient-based training can be made independent of the information exponent while also outperforming lower bounds for rotationally invariant kernel methods.
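As commonly defined, the information exponent of a link function sigma is the smallest k >= 1 for which the Hermite coefficient E[sigma(z) He_k(z)], with z ~ N(0, 1), is nonzero. The following minimal sketch estimates it with Gauss-Hermite quadrature from `numpy.polynomial.hermite_e`. The helper names, the degree cutoff, the number of quadrature points, and the tolerance are illustrative assumptions; for non-polynomial links the quadrature is only approximate, so coefficients below the tolerance are treated as zero.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as H


def hermite_coefficients(sigma, max_degree, quad_points=64):
    """Return c_k = E[sigma(z) He_k(z)] / k! for z ~ N(0, 1) and k = 0..max_degree."""
    nodes, weights = H.hermegauss(quad_points)
    # hermegauss uses the weight exp(-x^2 / 2); its weights sum to sqrt(2 * pi).
    weights = weights / math.sqrt(2.0 * math.pi)
    values = sigma(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0  # coefficient vector selecting He_k
        he_k = H.hermeval(nodes, basis)
        coeffs.append(np.sum(weights * values * he_k) / math.factorial(k))
    return np.array(coeffs)


def information_exponent(sigma, max_degree=10, tol=1e-8):
    """Smallest k >= 1 with a nonzero Hermite coefficient, or None if none up to max_degree."""
    c = hermite_coefficients(sigma, max_degree)
    for k in range(1, max_degree + 1):
        if abs(c[k]) > tol:
            return k
    return None
```

For instance, the sketch assigns exponent 1 to `np.tanh`, 2 to He_2(z) = z^2 - 1, and 3 to He_3(z) = z^3 - 3z. Dividing by k! makes each coefficient the weight of He_k in the expansion of sigma; only whether a coefficient vanishes matters for the exponent, so this normalization choice does not affect the result.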