LG · IT · ML · Jun 30, 2022

Neural Networks can Learn Representations with Gradient Descent

arXiv:2206.15144v1 · 189 citations · h-index: 40
Originality: Highly original
AI Analysis

This paper addresses a fundamental theoretical question in machine learning: why neural networks outperform kernel methods in practice. The answer has implications for representation learning and for the efficiency of transfer learning.

The paper tackles the gap between neural networks and kernel methods. It shows that gradient descent on a two-layer neural network can learn representations of low-degree polynomials that depend on only a few relevant directions. This yields a sample complexity of n ≍ d²r + dr^p, compared with n ≍ d^p for kernel methods (where d is the input dimension, r the number of relevant directions, and p the polynomial degree), and enables efficient transfer learning with a target sample complexity independent of d.
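To make the gap concrete, the sketch below plugs numbers into the two scalings quoted above. It drops all constants and logarithmic factors, so it compares growth rates only, not actual sample counts; the choice r = 2, p = 4 is arbitrary. The helper `head_coefficients` is a hypothetical illustration (not from the paper) of why transfer can be cheap: once U is known, a new head is a polynomial of degree at most p in r variables, which has comb(r + p, p) coefficients regardless of d. This is a rough parameter count, not the paper's bound.

```python
from math import comb


def kernel_scaling(d: int, p: int) -> int:
    """n ~ d^p: kernel-regime sample scaling (constants and log factors dropped)."""
    return d ** p


def gd_scaling(d: int, r: int, p: int) -> int:
    """n ~ d^2 r + d r^p: the paper's rate for gradient descent (constants dropped)."""
    return d ** 2 * r + d * r ** p


def head_coefficients(r: int, p: int) -> int:
    """Hypothetical illustration: number of monomials of degree <= p in r variables.

    Equals comb(r + p, p) and depends only on r and p, not on the ambient
    dimension d. This is a parameter count, not a sample-complexity bound.
    """
    return comb(r + p, p)


r, p = 2, 4  # arbitrary illustrative values
for d in (100, 1000, 10000):
    # Columns: d, kernel scaling, gradient-descent scaling, head size.
    print(d, kernel_scaling(d, p), gd_scaling(d, r, p), head_coefficients(r, p))
```

For d = 100 the kernel scaling is 10^8 while the gradient-descent scaling is 21,600, and the head size stays at 15 for every d.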

Significant theoretical work has established that in specific regimes, neural networks trained by gradient descent behave like kernel methods. However, in practice, it is known that neural networks strongly outperform their associated kernels. In this work, we explain this gap by demonstrating that there is a large class of functions which cannot be efficiently learned by kernel methods but can be easily learned with gradient descent on a two-layer neural network outside the kernel regime by learning representations that are relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime. Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e., of the form $f^\star(x) = g(Ux)$ where $U: \mathbb{R}^d \to \mathbb{R}^r$ with $d \gg r$. When the degree of $f^\star$ is $p$, it is known that $n \asymp d^p$ samples are necessary to learn $f^\star$ in the kernel regime. Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to $f^\star$. This results in an improved sample complexity of $n \asymp d^2 r + dr^p$. Furthermore, in a transfer learning setup where the data distributions in the source and target domains share the same representation $U$ but have different polynomial heads, we show that a popular heuristic for transfer learning has a target sample complexity independent of $d$.
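The target class $f^\star(x) = g(Ux)$ can be illustrated with a small pure-Python sketch. It assumes U has orthonormal rows and x is standard Gaussian, and it uses a hypothetical degree-3 link `g` chosen only for illustration; none of these specifics are taken from the paper. The sketch checks the defining property: moving x along any direction orthogonal to the rows of U leaves $f^\star$ unchanged, so only the r directions spanned by U matter.

```python
import math
import random


def orthonormal_rows(r: int, d: int, rng: random.Random) -> list[list[float]]:
    """Return r orthonormal vectors in R^d (assumed rows of U) via Gram-Schmidt."""
    rows: list[list[float]] = []
    while len(rows) < r:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for u in rows:  # remove components along previously accepted rows
            dot = sum(a * b for a, b in zip(v, u))
            v = [a - dot * b for a, b in zip(v, u)]
        norm = math.sqrt(sum(a * a for a in v))
        if norm > 1e-8:
            rows.append([a / norm for a in v])
    return rows


def project(U: list[list[float]], x: list[float]) -> list[float]:
    """Compute z = U x in R^r."""
    return [sum(a * b for a, b in zip(row, x)) for row in U]


def g(z: list[float]) -> float:
    """Hypothetical degree-3 link g: R^2 -> R, for illustration only."""
    return z[0] * z[1] + z[0] ** 3 - z[1]


def f_star(U: list[list[float]], x: list[float]) -> float:
    """Target f*(x) = g(Ux): depends on x only through its projection onto U."""
    return g(project(U, x))


rng = random.Random(0)
d, r = 50, 2
U = orthonormal_rows(r, d, rng)
x = [rng.gauss(0.0, 1.0) for _ in range(d)]  # assumed x ~ N(0, I_d)

# Build a direction w orthogonal to every row of U.
w = [rng.gauss(0.0, 1.0) for _ in range(d)]
for u in U:
    dot = sum(a * b for a, b in zip(w, u))
    w = [a - dot * b for a, b in zip(w, u)]

# A large move along w does not change f*, up to floating-point error.
x_perturbed = [a + 10.0 * b for a, b in zip(x, w)]
print(abs(f_star(U, x) - f_star(U, x_perturbed)) < 1e-8)
```

A representation that recovers the row span of U is therefore enough to predict $f^\star$, which is roughly the intuition behind the improved rate.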
