LG Jun 30, 2022
Neural Networks can Learn Representations with Gradient Descent
Alex Damian, Jason D. Lee, Mahdi Soltanolkotabi
Significant theoretical work has established that in specific regimes, neural networks trained by gradient descent behave like kernel methods. However, in practice, it is known that neural networks strongly outperform their associated kernels. In this work, we explain this gap by demonstrating that there is a large class of functions which cannot be efficiently learned by kernel methods but can be easily learned with gradient descent on a two-layer neural network outside the kernel regime by learning representations that are relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime. Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e. of the form $f^\star(x) = g(Ux)$ where $U: \mathbb{R}^d \to \mathbb{R}^r$ with $d \gg r$. When the degree of $f^\star$ is $p$, it is known that $n \asymp d^p$ samples are necessary to learn $f^\star$ in the kernel regime. Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to $f^\star$. This results in an improved sample complexity of $n\asymp d^2 r + dr^p$. Furthermore, in a transfer learning setup where the data distributions in the source and target domains share the same representation $U$ but have different polynomial heads, we show that a popular heuristic for transfer learning has a target sample complexity independent of $d$.
LG Sep 30, 2022
Self-Stabilization: The Implicit Bias of Gradient Descent at the Edge of Stability
Alex Damian, Eshaan Nichani, Jason D. Lee
Traditional analyses of gradient descent show that when the largest eigenvalue of the Hessian, also known as the sharpness $S(θ)$, is bounded by $2/η$, training is "stable" and the training loss decreases monotonically. Recent works, however, have observed that this assumption does not hold when training modern neural networks with full batch or large batch gradient descent. Most recently, Cohen et al. (2021) observed two important phenomena. The first, dubbed progressive sharpening, is that the sharpness steadily increases throughout training until it reaches the instability cutoff $2/η$. The second, dubbed edge of stability, is that the sharpness hovers at $2/η$ for the remainder of training while the loss continues decreasing, albeit non-monotonically. We demonstrate that, far from being chaotic, the dynamics of gradient descent at the edge of stability can be captured by a cubic Taylor expansion: as the iterates diverge in direction of the top eigenvector of the Hessian due to instability, the cubic term in the local Taylor expansion of the loss function causes the curvature to decrease until stability is restored. This property, which we call self-stabilization, is a general property of gradient descent and explains its behavior at the edge of stability. A key consequence of self-stabilization is that gradient descent at the edge of stability implicitly follows projected gradient descent (PGD) under the constraint $S(θ) \le 2/η$. Our analysis provides precise predictions for the loss, sharpness, and deviation from the PGD trajectory throughout training, which we verify both empirically in a number of standard settings and theoretically under mild conditions. Our analysis uncovers the mechanism for gradient descent's implicit bias towards stability.
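The classical stability threshold $2/η$ mentioned in this abstract can be checked on a one-dimensional quadratic. The sketch below is only an illustration of that threshold, not of the self-stabilization mechanism itself (a pure quadratic has no cubic term, so its sharpness never changes). The function name `gd_on_quadratic` and the specific values of the step size and sharpness are assumptions chosen for illustration.

```python
def gd_on_quadratic(sharpness, eta, x0=1.0, steps=50):
    """Run gradient descent on L(x) = 0.5 * sharpness * x**2.

    On this loss the Hessian is the constant `sharpness`, and each step
    multiplies the iterate by (1 - eta * sharpness). The iterates shrink
    exactly when |1 - eta * sharpness| < 1, i.e. when sharpness < 2 / eta.
    """
    iterates = [x0]
    for _ in range(steps):
        grad = sharpness * iterates[-1]
        iterates.append(iterates[-1] - eta * grad)
    return iterates


eta = 0.1                      # hypothetical step size, so 2 / eta = 20
below = gd_on_quadratic(sharpness=19.0, eta=eta)  # just below the threshold
above = gd_on_quadratic(sharpness=21.0, eta=eta)  # just above the threshold
print("final |x| below threshold:", abs(below[-1]))
print("final |x| above threshold:", abs(above[-1]))
```

With sharpness just below $2/η$ the iterates oscillate in sign but contract, while just above it they oscillate and grow. In a neural network the sharpness is not constant, which is where the cubic-term argument of the abstract comes in.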
LG Feb 22, 2024
How Transformers Learn Causal Structure with Gradient Descent
Eshaan Nichani, Alex Damian, Jason D. Lee
The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
LG Mar 6
Improved high-dimensional estimation with Langevin dynamics and stochastic weight averaging
Stanley Wei, Alex Damian, Jason D. Lee
Significant recent work has studied the ability of gradient descent to recover a hidden planted direction $θ^\star \in S^{d-1}$ in different high-dimensional settings, including tensor PCA and single-index models. The key quantity that governs the ability of gradient descent to traverse these landscapes is the information exponent $k^\star$ (Ben Arous et al., 2021), which corresponds to the order of the saddle at initialization in the population landscape. Ben Arous et al. (2021) showed that $n \gtrsim d^{\max(1, k^\star-1)}$ samples were necessary and sufficient for online SGD to recover $θ^\star$, and Ben Arous et al. (2020) proved a similar lower bound for Langevin dynamics. More recently, Damian et al. (2023) showed it was possible to circumvent these lower bounds by running gradient descent on a smoothed landscape, and that this algorithm succeeds with $n \gtrsim d^{\max(1, k^\star/2)}$ samples, which is optimal in the worst case. This raises the question of whether it is possible to achieve the same rate without explicit smoothing. In this paper, we show that Langevin dynamics can succeed with $n \gtrsim d^{k^\star/2}$ samples if one considers the average iterate, rather than the last iterate. The key idea is that the combination of noise-injection and iterate averaging is able to emulate the effect of landscape smoothing. We apply this result to both the tensor PCA and single-index model settings. Finally, we conjecture that minibatch SGD can also achieve the same rate without adding any additional noise.
LG Mar 8, 2024
Computational-Statistical Gaps in Gaussian Single-Index Models
Alex Damian, Loucas Pillaud-Vivien, Jason D. Lee et al.
Single-Index Models are high-dimensional regression problems with planted structure, whereby labels depend on an unknown one-dimensional projection of the input via a generic, non-linear, and potentially non-deterministic transformation. As such, they encompass a broad class of statistical inference tasks, and provide a rich template to study statistical and computational trade-offs in the high-dimensional regime. While the information-theoretic sample complexity to recover the hidden direction is linear in the dimension $d$, we show that computationally efficient algorithms, both within the Statistical Query (SQ) and the Low-Degree Polynomial (LDP) framework, necessarily require $Ω(d^{k^\star/2})$ samples, where $k^\star$ is a "generative" exponent associated with the model that we explicitly characterize. Moreover, we show that this sample complexity is also sufficient, by establishing matching upper bounds using a partial-trace algorithm. Therefore, our results provide evidence of a sharp computational-to-statistical gap (under both the SQ and LDP classes) whenever $k^\star>2$. To complete the study, we provide examples of smooth and Lipschitz deterministic target functions with arbitrarily large generative exponents $k^\star$.
LG Oct 31, 2024
Understanding Optimization in Deep Learning with Central Flows
Jeremy M. Cohen, Alex Damian, Ameet Talwalkar et al.
Traditional theories of optimization cannot describe the dynamics of optimization in deep learning, even in the simple setting of deterministic training. The challenge is that optimizers typically operate in a complex, oscillatory regime called the "edge of stability." In this paper, we develop theory that can describe the dynamics of optimization in this regime. Our key insight is that while the *exact* trajectory of an oscillatory optimizer may be challenging to analyze, the *time-averaged* (i.e. smoothed) trajectory is often much more tractable. To analyze an optimizer, we derive a differential equation called a "central flow" that characterizes this time-averaged trajectory. We empirically show that these central flows can predict long-term optimization trajectories for generic neural networks with a high degree of numerical accuracy. By interpreting these central flows, we are able to understand how gradient descent makes progress even as the loss sometimes goes up; how adaptive optimizers "adapt" to the local loss landscape; and how adaptive optimizers implicitly navigate towards regions where they can take larger steps. Our results suggest that central flows can be a valuable theoretical tool for reasoning about optimization in deep learning.
LG May 29, 2025
Learning Compositional Functions with Transformers from Easy-to-Hard Data
Zixuan Wang, Eshaan Nichani, Alberto Bietti et al.
Transformer-based language models have demonstrated impressive capabilities across a range of complex reasoning tasks. Prior theoretical work exploring the expressive power of transformers has shown that they can efficiently perform multi-step reasoning tasks involving parallelizable computations. However, the learnability of such constructions, particularly the conditions on the data distribution that enable efficient learning via gradient-based optimization, remains an open question. Towards answering this question, in this work we study the learnability of the $k$-fold composition task, which requires computing an interleaved composition of $k$ input permutations and $k$ hidden permutations, and can be expressed by a transformer with $O(\log k)$ layers. On the negative front, we prove a Statistical Query (SQ) lower bound showing that any SQ learner that makes only polynomially-many queries to an SQ oracle for the $k$-fold composition task distribution must have sample size exponential in $k$, thus establishing a statistical-computational gap. On the other hand, we show that this function class can be efficiently learned, with runtime and sample complexity polynomial in $k$, by gradient descent on an $O(\log k)$-depth transformer via two different curriculum learning strategies: one in which data consists of $k'$-fold composition functions with $k' \le k$ presented in increasing difficulty, and another in which all such data is presented simultaneously. Our work sheds light on the necessity and sufficiency of having both easy and hard examples in the data distribution for transformers to learn complex compositional tasks.
LG Jun 5, 2025
The Generative Leap: Sharp Sample Complexity for Efficiently Learning Gaussian Multi-Index Models
Alex Damian, Jason D. Lee, Joan Bruna
In this work we consider generic Gaussian Multi-index models, in which the labels only depend on the (Gaussian) $d$-dimensional inputs through their projection onto a low-dimensional $r = O_d(1)$ subspace, and we study efficient agnostic estimation procedures for this hidden subspace. We introduce the *generative leap* exponent $k^\star$, a natural extension of the generative exponent from [Damian et al.'24] to the multi-index setting. We first show that a sample complexity of $n=Θ(d^{1 \vee k^\star/2})$ is necessary in the class of algorithms captured by the Low-Degree-Polynomial framework. We then establish that this sample complexity is also sufficient, by giving an agnostic sequential estimation procedure (that is, requiring no prior knowledge of the multi-index model) based on a spectral U-statistic over appropriate Hermite tensors. We further compute the generative leap exponent for several examples including piecewise linear functions (deep ReLU networks with bias), and general deep neural networks (with $r$-dimensional first hidden layer).
LG May 27, 2023
Fine-Tuning Language Models with Just Forward Passes
Sadhika Malladi, Tianyu Gao, Eshaan Nichani et al.
Fine-tuning language models (LMs) has yielded success on diverse downstream tasks, but as LMs grow in size, backpropagation requires a prohibitively large amount of memory. Zeroth-order (ZO) methods can in principle estimate gradients using only two forward passes but are theorized to be catastrophically slow for optimizing large models. In this work, we propose a memory-efficient zeroth-order optimizer (MeZO), adapting the classical ZO-SGD method to operate in-place, thereby fine-tuning LMs with the same memory footprint as inference. For example, with a single A100 80GB GPU, MeZO can train a 30-billion parameter model, whereas fine-tuning with backpropagation can train only a 2.7B LM with the same budget. We conduct comprehensive experiments across model types (masked and autoregressive LMs), model scales (up to 66B), and downstream tasks (classification, multiple-choice, and generation). Our results demonstrate that (1) MeZO significantly outperforms in-context learning and linear probing; (2) MeZO achieves comparable performance to fine-tuning with backpropagation across multiple tasks, with up to 12x memory reduction and up to 2x GPU-hour reduction in our implementation; (3) MeZO is compatible with both full-parameter and parameter-efficient tuning techniques such as LoRA and prefix tuning; (4) MeZO can effectively optimize non-differentiable objectives (e.g., maximizing accuracy or F1). We support our empirical findings with theoretical insights, highlighting how adequate pre-training and task prompts enable MeZO to fine-tune huge models, despite classical ZO analyses suggesting otherwise.
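The idea of estimating a gradient from two forward passes and updating the parameters in place can be sketched with a generic two-point (SPSA-style) zeroth-order step. This is a minimal NumPy illustration under stated assumptions, not the authors' implementation: the function name `zo_sgd_step`, the toy loss, and the choice to regenerate the random direction from a seed rather than store it are all illustrative, and the direction is materialized as one array here for brevity.

```python
import numpy as np


def zo_sgd_step(theta, loss_fn, lr=0.05, eps=1e-3, seed=0):
    """One in-place zeroth-order SGD step using two loss evaluations.

    `theta` is a float NumPy array that is modified in place.
    `loss_fn` maps the parameter array to a scalar loss (a "forward pass").
    """

    def perturb(scale):
        # Re-create the same Gaussian direction z from the seed each time,
        # so z never has to be kept alive between calls.
        z = np.random.default_rng(seed).standard_normal(theta.shape)
        theta[...] += scale * z

    perturb(+eps)                 # theta + eps * z
    loss_plus = loss_fn(theta)
    perturb(-2.0 * eps)           # theta - eps * z
    loss_minus = loss_fn(theta)
    perturb(+eps)                 # restore the original theta

    # Finite-difference estimate of the directional derivative along z.
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)

    # SGD update along the same direction z.
    perturb(-lr * projected_grad)
    return projected_grad


def toy_loss(params):
    # Hypothetical stand-in for a model's forward pass.
    return 0.5 * float(np.dot(params, params))


theta = np.ones(5)
for step in range(500):
    zo_sgd_step(theta, toy_loss, seed=step)
print("final toy loss:", toy_loss(theta))
```

Because only loss values are used, `loss_fn` does not need to be differentiable, which matches the abstract's remark about optimizing objectives such as accuracy or F1.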
LG May 18, 2023
Smoothing the Landscape Boosts the Signal for SGD: Optimal Sample Complexity for Learning Single Index Models
Alex Damian, Eshaan Nichani, Rong Ge et al.
We focus on the task of learning a single index model $σ(w^\star \cdot x)$ with respect to the isotropic Gaussian distribution in $d$ dimensions. Prior work has shown that the sample complexity of learning $w^\star$ is governed by the information exponent $k^\star$ of the link function $σ$, which is defined as the index of the first nonzero Hermite coefficient of $σ$. Ben Arous et al. (2021) showed that $n \gtrsim d^{k^\star-1}$ samples suffice for learning $w^\star$ and that this is tight for online SGD. However, the CSQ lower bound for gradient based methods only shows that $n \gtrsim d^{k^\star/2}$ samples are necessary. In this work, we close the gap between the upper and lower bounds by showing that online SGD on a smoothed loss learns $w^\star$ with $n \gtrsim d^{k^\star/2}$ samples. We also draw connections to statistical analyses of tensor PCA and to the implicit regularization effects of minibatch SGD on empirical losses.
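The definition of the information exponent as the index of the first nonzero Hermite coefficient of $σ$ can be checked numerically. The sketch below uses Gauss-Hermite quadrature for the standard Gaussian measure with probabilists' Hermite polynomials. The function names, the truncation degree, the tolerance, and the convention of starting the search at index 1 (skipping the constant term) are assumptions made for illustration.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval


def hermite_coefficients(sigma, max_degree=8, quad_points=64):
    """Approximate E[sigma(Z) He_k(Z)] / sqrt(k!) for Z ~ N(0, 1), k = 0..max_degree."""
    nodes, weights = hermegauss(quad_points)
    weights = weights / np.sqrt(2.0 * np.pi)  # turn the weights into a Gaussian expectation
    values = sigma(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                         # selects the k-th polynomial He_k
        he_k = hermeval(nodes, basis)
        coeffs.append(np.sum(weights * values * he_k) / math.sqrt(math.factorial(k)))
    return np.array(coeffs)


def information_exponent(sigma, max_degree=8, tol=1e-8):
    """Smallest k >= 1 whose Hermite coefficient exceeds `tol` in magnitude, else None."""
    coeffs = hermite_coefficients(sigma, max_degree)
    for k in range(1, max_degree + 1):
        if abs(coeffs[k]) > tol:
            return k
    return None


print(information_exponent(np.tanh))               # odd link with a linear component
print(information_exponent(lambda x: x**2))        # even link, no linear component
print(information_exponent(lambda x: x**3 - 3*x))  # the third Hermite polynomial itself
```

A link with a nonzero linear component has exponent 1, a link such as $x^2$ has exponent 2, and $x^3 - 3x$ has exponent 3, which is where the $d^{k^\star-1}$ and $d^{k^\star/2}$ rates in the abstract start to differ.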
LG May 11, 2023
Provable Guarantees for Nonlinear Feature Learning in Three-Layer Neural Networks
Eshaan Nichani, Alex Damian, Jason D. Lee
One of the central questions in the theory of deep learning is to understand how neural networks learn hierarchical features. The ability of deep networks to extract salient features is crucial to both their outstanding generalization ability and the modern deep learning paradigm of pretraining and finetuning. However, this feature learning process remains poorly understood from a theoretical perspective, with existing analyses largely restricted to two-layer networks. In this work we show that three-layer neural networks have provably richer feature learning capabilities than two-layer networks. We analyze the features learned by a three-layer network trained with layer-wise gradient descent, and present a general purpose theorem which upper bounds the sample complexity and width needed to achieve low test error when the target has specific hierarchical structure. We instantiate our framework in specific statistical learning settings -- single-index models and functions of quadratic features -- and show that in the latter setting three-layer networks obtain a sample complexity improvement over all existing guarantees for two-layer networks. Crucially, this sample complexity improvement relies on the ability of three-layer networks to efficiently learn nonlinear features. We then establish a concrete optimization-based depth separation by constructing a function which is efficiently learnable via gradient descent on a three-layer network, yet cannot be learned efficiently by a two-layer network. Our work makes progress towards understanding the provable benefit of three-layer neural networks over two-layer networks in the feature learning regime.
LG Jun 11, 2021
Label Noise SGD Provably Prefers Flat Global Minimizers
Alex Damian, Tengyu Ma, Jason D. Lee
In overparametrized models, the noise in stochastic gradient descent (SGD) implicitly regularizes the optimization trajectory and determines which local minimum SGD converges to. Motivated by empirical studies that demonstrate that training with noisy labels improves generalization, we study the implicit regularization effect of SGD with label noise. We show that SGD with label noise converges to a stationary point of a regularized loss $L(θ) +λR(θ)$, where $L(θ)$ is the training loss, $λ$ is an effective regularization parameter depending on the step size, strength of the label noise, and the batch size, and $R(θ)$ is an explicit regularizer that penalizes sharp minimizers. Our analysis uncovers an additional regularization effect of large learning rates beyond the linear scaling rule that penalizes large eigenvalues of the Hessian more than small ones. We also prove extensions to classification with general loss functions, SGD with momentum, and SGD with general noise covariance, significantly strengthening the prior work of Blanc et al. to global convergence and large learning rates and of HaoChen et al. to general models.
CV May 9, 2018
New Techniques for Preserving Global Structure and Denoising with Low Information Loss in Single-Image Super-Resolution
Yijie Bei, Alex Damian, Shijia Hu et al.
This work identifies and addresses two important technical challenges in single-image super-resolution: (1) how to upsample an image without magnifying noise and (2) how to preserve large scale structure when upsampling. We summarize the techniques we developed for our second place entry in Track 1 (Bicubic Downsampling), seventh place entry in Track 2 (Realistic Adverse Conditions), and seventh place entry in Track 3 (Realistic difficult) in the 2018 NTIRE Super-Resolution Challenge. Furthermore, we present new neural network architectures that specifically address the two challenges listed above: denoising and preservation of large-scale structure.