Jason D. Lee

LG
h-index: 96
158 papers
15,936 citations
Novelty: 63%
AI Score: 64

158 Papers

AI · Jun 3 · Code
LeanMarathon: Toward Reliable AI Co-Mathematicians through Long-Horizon Lean Autoformalization

Yuanhe Zhang, Yuekai Sun, Taiji Suzuki et al.

Long-horizon autoformalization of research mathematics fails not only at hard lemmas, but at scale: statements drift, dependencies tangle, context decays, and local repairs corrupt distant work. We present LeanMarathon, a multi-agent harness for reliable research-level Lean autoformalization. Its core abstraction is an evolving blueprint: a Lean file that serves simultaneously as formal proof skeleton, natural-language proof graph, and shared system of record. Four contract-scoped agents construct, audit, prove, and repair this blueprint. These agents are coordinated by a two-stage orchestrator that first stabilizes target fidelity through adversarial review and then discharges the proof directed acyclic graph (DAG) from its dynamic leaves upward in parallel CI-gated rounds. LeanMarathon turns one brittle multi-hour run into many local, recoverable, parallel transactions. We evaluate LeanMarathon on two recent research papers spanning four Erdős problems (#1051, #1196, #164, #1217). Across three autonomous runs, it formalizes all seven target theorems with no remaining sorry placeholders, proving 258 lemmas and theorems. These results show that reliable AI co-mathematics requires not only stronger provers, but durable harnesses that preserve target fidelity across long mathematical developments. The code can be found at https://github.com/YuanheZ/LeanMarathon.
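
The scheduling idea of discharging the proof DAG "from its dynamic leaves upward" in rounds can be illustrated with a layered topological sort. This is a minimal sketch, not LeanMarathon's code: the `leaf_rounds` helper and the lemma names are hypothetical, and it assumes every proof attempt passes its CI gate, whereas the real harness interleaves proving, auditing, and repair.

```python
def leaf_rounds(deps: dict[str, set[str]]) -> list[list[str]]:
    """Group lemmas into rounds; a lemma becomes ready once all its dependencies are proved.

    `deps` maps each lemma to the set of lemmas its proof uses (hypothetical blueprint).
    """
    # Collect every node, including dependencies that never appear as keys.
    nodes = set(deps)
    for ds in deps.values():
        nodes |= ds
    remaining = {n: set(deps.get(n, ())) for n in nodes}
    proved: set[str] = set()
    rounds: list[list[str]] = []
    while remaining:
        # Current "dynamic leaves": unproved lemmas whose dependencies are all proved.
        ready = sorted(n for n, ds in remaining.items() if ds <= proved)
        if not ready:
            raise ValueError("dependency cycle: blueprint is not a DAG")
        rounds.append(ready)   # a real harness would attempt these in parallel
        proved.update(ready)   # assumption: every attempt succeeds
        for n in ready:
            del remaining[n]
    return rounds
```

For example, with `{"main": {"lemA", "lemB"}, "lemA": {"aux"}, "lemB": set(), "aux": set()}` the rounds are `[["aux", "lemB"], ["lemA"], ["main"]]`, so independent leaves are attempted together and a failure stays local to one round.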

LG · Jan 27, 2023
Understanding Incremental Learning of Gradient Descent: A Fine-grained Analysis of Matrix Sensing

Jikai Jin, Zhiyuan Li, Kaifeng Lyu et al. · stanford, tsinghua

It is believed that Gradient Descent (GD) induces an implicit bias towards good generalization in training machine learning models. This paper provides a fine-grained analysis of the dynamics of GD for the matrix sensing problem, whose goal is to recover a low-rank ground-truth matrix from near-isotropic linear measurements. It is shown that GD with small initialization behaves similarly to the greedy low-rank learning heuristics (Li et al., 2020) and follows an incremental learning procedure (Gissin et al., 2019): GD sequentially learns solutions with increasing ranks until it recovers the ground truth matrix. Compared to existing works, which only analyze the first learning phase for rank-1 solutions, our result provides characterizations for the whole learning process. Moreover, besides the over-parameterized regime that many prior works focused on, our analysis of the incremental learning procedure also applies to the under-parameterized regime. Finally, we conduct numerical experiments to confirm our theoretical findings.
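
The incremental, rank-by-rank behavior of GD with small initialization can be observed in a toy run. This sketch idealizes matrix sensing by observing the target matrix directly (the population limit of isotropic measurements) and runs GD on the factorized loss $f(U) = \frac14\|UU^\top - M\|_F^2$; the target `M`, initialization scale, step size, and the helper `gd_factorization` are illustrative assumptions, not the paper's experimental setup.

```python
import numpy as np

def gd_factorization(M, rank, init_scale=1e-3, lr=0.01, steps=3000, seed=0):
    """GD on f(U) = 1/4 * ||U U^T - M||_F^2 from a small random initialization.

    Returns the final factor U and, for every step, the top `rank` eigenvalues of U U^T.
    """
    rng = np.random.default_rng(seed)
    d = M.shape[0]
    U = init_scale * rng.standard_normal((d, rank))
    history = []
    for _ in range(steps):
        # For symmetric M, the gradient of f is (U U^T - M) U.
        U -= lr * (U @ U.T - M) @ U
        # Eigenvalues of the current estimate, sorted in decreasing order.
        history.append(np.linalg.eigvalsh(U @ U.T)[::-1][:rank])
    return U, np.array(history)

M = np.diag([4.0, 1.0, 0.0, 0.0, 0.0])
U, hist = gd_factorization(M, rank=3)
```

With these settings the eigenvalue-4 direction is typically fitted within a few hundred steps while the eigenvalue-1 direction is still near zero, and only later does the second direction grow, mirroring the sequential learning described above.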

CL · Nov 14, 2023 · Code
REST: Retrieval-Based Speculative Decoding

Zhenyu He, Zexuan Zhong, Tianle Cai et al.

We introduce Retrieval-Based Speculative Decoding (REST), a novel algorithm designed to speed up language model generation. The key insight driving the development of REST is the observation that the process of text generation often includes certain common phases and patterns. Unlike previous methods that rely on a draft language model for speculative decoding, REST harnesses the power of retrieval to generate draft tokens. This method draws from the reservoir of existing knowledge, retrieving and employing relevant tokens based on the current context. Its plug-and-play nature allows for seamless integration and acceleration of any language model, all without necessitating additional training. When benchmarked on 7B and 13B language models in a single-batch setting, REST achieves a significant speedup of 1.62× to 2.36× on code or text generation. The code of REST is available at https://github.com/FasterDecoding/REST.
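
The retrieve-then-verify loop can be sketched on plain token lists. This is a simplified illustration, not the REST implementation: the datastore is a single flat token list scanned linearly (the actual system uses an indexed datastore and builds multiple drafts), verification calls a hypothetical greedy `next_token` function once per draft token instead of one parallel forward pass, and the helper names are invented for the example.

```python
def longest_suffix_draft(context, datastore, max_suffix=4, draft_len=3):
    """Find the longest suffix of `context` occurring in `datastore` and return the
    tokens that followed its first occurrence (the retrieved draft)."""
    for k in range(min(max_suffix, len(context)), 0, -1):
        suffix = context[-k:]
        # Stop early enough that at least one continuation token exists.
        for i in range(len(datastore) - k):
            if datastore[i:i + k] == suffix:
                return datastore[i + k:i + k + draft_len]
    return []

def verify_greedy(context, draft, next_token):
    """Accept the longest draft prefix that agrees with the model's greedy choices,
    then append one token produced by the model itself."""
    accepted = []
    for tok in draft:
        if next_token(context + accepted) != tok:
            break
        accepted.append(tok)
    accepted.append(next_token(context + accepted))
    return accepted
```

With the datastore `"def foo ( x ) : return x".split()` and context `"def bar ( x )".split()`, the suffix `( x )` is matched and the draft is `[":", "return", "x"]`; if the model agrees with all three tokens, one verification step emits four tokens.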

LG · Nov 30, 2023
Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking

Kaifeng Lyu, Jikai Jin, Zhiyuan Li et al. · stanford, tsinghua

Recent work by Power et al. (2022) highlighted a surprising "grokking" phenomenon in learning arithmetic tasks: a neural net first "memorizes" the training set, resulting in perfect training accuracy but near-random test accuracy, and, after sufficiently longer training, it suddenly transitions to perfect test accuracy. This paper studies the grokking phenomenon in theoretical setups and shows that it can be induced by a dichotomy of early and late phase implicit biases. Specifically, when training homogeneous neural nets with large initialization and small weight decay on both classification and regression tasks, we prove that the training process gets trapped at a solution corresponding to a kernel predictor for a long time, and then a very sharp transition to min-norm/max-margin predictors occurs, leading to a dramatic change in test accuracy.

LG · May 26
Fine-Tuning Dynamics of In-Context Factual Recall in Transformers

Ruomin Huang, Eshaan Nichani, Jason D. Lee et al.

In-context learning (performing tasks based on examples given in the prompt) is an important capability that has emerged in large language models and has received significant attention in both theory and practice. Existing theoretical work often focuses on settings where the learning uses information purely from the prompt. However, many practical instances of in-context learning require the model to retrieve factual knowledge stored in the model's parameters, with the context serving to identify which knowledge is relevant. In this work, we study how in-context learning leverages factual knowledge recall. We formalize this behavior by introducing the in-context factual recall (IC-recall) task, where a transformer is provided a context of (subject, answer) pairs generated from a hidden relation, along with a query subject, and must both infer this hidden relation and retrieve the corresponding answer. Factual knowledge is modeled by the transformer having access to a simple pre-constructed MLP associative memory storing (subject, relation, answer) triplets. We analyze the supervised fine-tuning dynamics of a one-layer transformer on IC-recall data and prove that the model successfully performs IC-recall by converging to a particular pairwise attention pattern. This fine-tuning stage requires a very small number of samples: only polylogarithmic in the number of stored knowledge triplets. Experiments verify our theoretical predictions and show that the pairwise attention pattern emerges even when the MLP layer is pretrained instead of constructed.

ML · Mar 17, 2014
Proximal Newton-type methods for minimizing composite functions

Jason D. Lee, Yuekai Sun, Michael A. Saunders

We generalize Newton-type methods for minimizing smooth functions to handle a sum of two convex functions: a smooth function and a nonsmooth function with a simple proximal mapping. We show that the resulting proximal Newton-type methods inherit the desirable convergence behavior of Newton-type methods for minimizing smooth functions, even when search directions are computed inexactly. Many popular methods tailored to problems arising in bioinformatics, signal processing, and statistical learning are special cases of proximal Newton-type methods, and our analysis yields new convergence results for some of these methods.
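
The simplest member of this family takes the Hessian approximation to be a scaled identity, $H = I/t$; the scaled proximal subproblem then reduces to an ordinary proximal step, giving proximal gradient descent. The sketch below applies that special case to the lasso, where the nonsmooth term $\lambda\|x\|_1$ has the soft-thresholding proximal map. It is an illustration under these assumptions only: a genuine proximal Newton method uses a richer $H$ and solves its subproblem iteratively, which this code does not do, and the function names are hypothetical.

```python
import numpy as np

def soft_threshold(v, tau):
    # Proximal map of tau * ||.||_1: shrink every coordinate toward zero by tau.
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def proximal_gradient_lasso(A, b, lam, steps=500):
    """Minimize 0.5 * ||A x - b||^2 + lam * ||x||_1 using H = I / t (proximal gradient)."""
    t = 1.0 / np.linalg.norm(A, 2) ** 2   # step 1/L, where L is the largest eigenvalue of A^T A
    x = np.zeros(A.shape[1])
    for _ in range(steps):
        grad = A.T @ (A @ x - b)          # gradient of the smooth part only
        x = soft_threshold(x - t * grad, t * lam)
    return x
```

For a diagonal `A` the problem separates by coordinate, so with `A = diag(2, 1)`, `b = (3, 0.5)` and `lam = 1` the minimizer can be checked by hand: $(\operatorname{soft}(6, 1)/4,\ \operatorname{soft}(0.5, 1)) = (1.25, 0)$.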

LG · Jul 12, 2022
PAC Reinforcement Learning for Predictive State Representations

Wenhao Zhan, Masatoshi Uehara, Wen Sun et al. · harvard

In this paper we study online Reinforcement Learning (RL) in partially observable dynamical systems. We focus on the Predictive State Representations (PSRs) model, which is an expressive model that captures other well-known models such as Partially Observable Markov Decision Processes (POMDP). PSR represents the states using a set of predictions of future observations and is defined entirely using observable quantities. We develop a novel model-based algorithm for PSRs that can learn a near-optimal policy in sample complexity scaling polynomially with respect to all the relevant parameters of the systems. Our algorithm naturally works with function approximation to extend to systems with potentially large state and observation spaces. We show that given a realizable model class, the sample complexity of learning the near-optimal policy only scales polynomially with respect to the statistical complexity of the model class, without any explicit polynomial dependence on the size of the state and observation spaces. Notably, our work is the first work that shows polynomial sample complexities to compete with the globally optimal policy in PSRs. Finally, we demonstrate how our general theorem can be directly used to derive sample complexity bounds for special models including $m$-step weakly revealing and $m$-step decodable tabular POMDPs, POMDPs with low-rank latent transition, and POMDPs with linear emission and latent transition.

LG · Jun 24, 2022
Provably Efficient Reinforcement Learning in Partially Observable Dynamical Systems

Masatoshi Uehara, Ayush Sekhari, Jason D. Lee et al. · harvard

We study Reinforcement Learning for partially observable dynamical systems using function approximation. We propose a new Partially Observable Bilinear Actor-Critic framework, which is general enough to include models such as observable tabular Partially Observable Markov Decision Processes (POMDPs), observable Linear-Quadratic-Gaussian (LQG), Predictive State Representations (PSRs), as well as a newly introduced model, Hilbert Space Embeddings of POMDPs, and observable POMDPs with latent low-rank transition. Under this framework, we propose an actor-critic style algorithm that is capable of performing agnostic policy learning. Given a policy class that consists of memory-based policies (that look at a fixed-length window of recent observations), and a value function class that consists of functions taking both memory and future observations as inputs, our algorithm learns to compete against the best memory-based policy in the given policy class. For certain examples such as undercomplete observable tabular POMDPs, observable LQGs and observable POMDPs with latent low-rank transition, by implicitly leveraging their special properties, our algorithm is even capable of competing against the globally optimal policy without paying an exponential dependence on the horizon in its sample complexity.

LG · Mar 29, 2022
Nearly Minimax Algorithms for Linear Bandits with Shared Representation

Jiaqi Yang, Qi Lei, Jason D. Lee et al. · tsinghua

We give novel algorithms for multi-task and lifelong linear bandits with shared representation. Specifically, we consider the setting where we play $M$ linear bandits with dimension $d$, each for $T$ rounds, and these $M$ bandit tasks share a common $k(\ll d)$ dimensional linear representation. For both the multi-task setting where we play the tasks concurrently, and the lifelong setting where we play tasks sequentially, we come up with novel algorithms that achieve $\widetilde{O}\left(d\sqrt{kMT} + kM\sqrt{T}\right)$ regret bounds, which match the known minimax regret lower bound up to logarithmic factors and close the gap in existing results [Yang et al., 2021]. Our main technique includes a more efficient estimator for the low-rank linear feature extractor and an accompanying novel analysis for this estimator.

LG · Jun 30, 2022
Neural Networks can Learn Representations with Gradient Descent

Alex Damian, Jason D. Lee, Mahdi Soltanolkotabi

Significant theoretical work has established that in specific regimes, neural networks trained by gradient descent behave like kernel methods. However, in practice, it is known that neural networks strongly outperform their associated kernels. In this work, we explain this gap by demonstrating that there is a large class of functions which cannot be efficiently learned by kernel methods but can be easily learned with gradient descent on a two-layer neural network outside the kernel regime by learning representations that are relevant to the target task. We also demonstrate that these representations allow for efficient transfer learning, which is impossible in the kernel regime. Specifically, we consider the problem of learning polynomials which depend on only a few relevant directions, i.e. of the form $f^\star(x) = g(Ux)$ where $U: \mathbb{R}^d \to \mathbb{R}^r$ with $d \gg r$. When the degree of $f^\star$ is $p$, it is known that $n \asymp d^p$ samples are necessary to learn $f^\star$ in the kernel regime. Our primary result is that gradient descent learns a representation of the data which depends only on the directions relevant to $f^\star$. This results in an improved sample complexity of $n\asymp d^2 r + dr^p$. Furthermore, in a transfer learning setup where the data distributions in the source and target domain share the same representation $U$ but have different polynomial heads we show that a popular heuristic for transfer learning has a target sample complexity independent of $d$.
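
To give a sense of the gap between the two rates, plug in illustrative values (not taken from the paper) $d = 100$, $r = 2$, $p = 4$ and ignore constant factors:

$$d^p = 100^4 = 10^8, \qquad d^2 r + d\, r^p = 100^2 \cdot 2 + 100 \cdot 2^4 = 20{,}000 + 1{,}600 = 21{,}600.$$

The kernel-regime requirement grows like $d^p$, while the representation-learning rate keeps the high power $p$ on the small dimension $r$ and is only quadratic in $d$.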

LG · Jan 30, 2023
Looped Transformers as Programmable Computers

Angeliki Giannou, Shashank Rajput, Jy-yong Sohn et al.

We present a framework for using transformer networks as universal computers by programming them with specific weights and placing them in a loop. Our input sequence acts as a punchcard, consisting of instructions and memory for data read/writes. We demonstrate that a constant number of encoder layers can emulate basic computing blocks, including embedding edit operations, non-linear functions, function calls, program counters, and conditional branches. Using these building blocks, we emulate a small instruction-set computer. This allows us to map iterative algorithms to programs that can be executed by a looped, 13-layer transformer. We show how this transformer, instructed by its input, can emulate a basic calculator, a basic linear algebra library, and in-context learning algorithms that employ backpropagation. Our work highlights the versatility of the attention mechanism, and demonstrates that even shallow transformers can execute full-fledged, general-purpose programs.
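
To our understanding, the small instruction-set computer emulated in this work is built around SUBLEQ ("subtract and branch if less than or equal to zero"), a one-instruction set computer. The interpreter below shows only the instruction semantics, not the transformer construction; as an analogy, each pass of the loop corresponds to one pass of the looped transformer over its "punchcard" of instructions and memory. The halting convention (a negative program counter), the helper `run_subleq`, and the sample program are assumptions made for the example.

```python
def run_subleq(mem, pc=0, max_steps=10_000):
    """Execute SUBLEQ a, b, c: mem[b] -= mem[a]; if mem[b] <= 0 jump to c, else pc += 3.

    Instructions and data share one memory list; a negative pc halts the machine.
    """
    mem = list(mem)
    for _ in range(max_steps):
        if pc < 0:
            return mem
        a, b, c = mem[pc], mem[pc + 1], mem[pc + 2]
        mem[b] -= mem[a]
        pc = c if mem[b] <= 0 else pc + 3
    raise RuntimeError("step limit reached")

# Program computing Y += X with a scratch cell Z (addresses 9, 10, 11):
#   0: subleq X, Z, 3    -> Z = -X
#   3: subleq Z, Y, 6    -> Y = Y - (-X)
#   6: subleq Z, Z, -1   -> Z = 0, then halt
program = [9, 11, 3, 11, 10, 6, 11, 11, -1, 7, 5, 0]
```

Running `run_subleq(program)` leaves `12` (that is, `5 + 7`) at address 10, showing how addition is assembled from the single subtract-and-branch primitive.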

LG · Jun 24, 2022
Computationally Efficient PAC RL in POMDPs with Latent Determinism and Conditional Embeddings

Masatoshi Uehara, Ayush Sekhari, Jason D. Lee et al. · harvard

We study reinforcement learning with function approximation for large-scale Partially Observable Markov Decision Processes (POMDPs) where the state space and observation space are large or even continuous. In particular, we consider Hilbert space embeddings of POMDPs where the feature of latent states and the feature of observations admit a conditional Hilbert space embedding of the observation emission process, and the latent state transition is deterministic. Under the function approximation setup where the optimal latent state-action $Q$-function is linear in the state feature, and the optimal $Q$-function has a gap in actions, we provide a computationally and statistically efficient algorithm for finding the exact optimal policy. We show our algorithm's computational and statistical complexities scale polynomially with respect to the horizon and the intrinsic dimension of the feature on the observation space. Furthermore, we show both the deterministic latent transitions and gap assumptions are necessary to avoid statistical complexity exponential in horizon or dimension. Since our guarantee does not have an explicit dependence on the size of the state and observation spaces, our algorithm provably scales to large-scale POMDPs.

LG · Feb 5, 2023
Offline Minimax Soft-Q-learning Under Realizability and Partial Coverage

Masatoshi Uehara, Nathan Kallus, Jason D. Lee et al. · harvard

In offline reinforcement learning (RL) we have no opportunity to explore, so we must make assumptions that the data is sufficient to guide picking a good policy, taking the form of assuming some coverage, realizability, Bellman completeness, and/or hard margin (gap). In this work we propose value-based algorithms for offline RL with PAC guarantees under just partial coverage, specifically, coverage of just a single comparator policy, and realizability of the soft (entropy-regularized) Q-function of the single policy and a related function defined as a saddle point of a certain minimax optimization problem. This offers refined and generally more lax conditions for offline RL. We further show an analogous result for vanilla Q-functions under a soft margin condition. To attain these guarantees, we leverage novel minimax learning algorithms to accurately estimate soft or vanilla Q-functions with $L^2$-convergence guarantees. Our algorithms' loss functions arise from casting the estimation problems as nonlinear convex optimization problems and Lagrangifying.

LG · Jul 7, 2023
Teaching Arithmetic to Small Transformers

Nayoung Lee, Kartik Sreenivasan, Jason D. Lee et al.

Large language models like GPT-4 exhibit emergent capabilities across general-purpose tasks, such as basic arithmetic, when trained on extensive text data, even though these tasks are not explicitly encoded by the unsupervised, next-token prediction objective. This study investigates how small transformers, trained from random initialization, can efficiently learn arithmetic operations such as addition, multiplication, and elementary functions like square root, using the next-token prediction objective. We first demonstrate that conventional training data is not the most effective for arithmetic learning, and simple formatting changes can significantly improve accuracy. This leads to sharp phase transitions as a function of training data scale, which, in some cases, can be explained through connections to low-rank matrix completion. Building on prior work, we then train on chain-of-thought style data that includes intermediate step results. Even in the complete absence of pretraining, this approach significantly and simultaneously improves accuracy, sample complexity, and convergence speed. We also study the interplay between arithmetic and text data during training and examine the effects of few-shot prompting, pretraining, and model scale. Additionally, we discuss length generalization challenges. Our work highlights the importance of high-quality, instructive data that considers the particular characteristics of the next-word prediction objective for rapidly eliciting arithmetic capabilities.
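
As a concrete example of a "simple formatting change", one variant studied in this line of work is, to our understanding, writing the sum with its digits reversed, so that the model emits the least-significant digit first and each output digit depends only on operand digits and a carry that are already available during left-to-right generation. The helpers below are hypothetical and only produce training strings; the exact formats used in the paper should be checked against it.

```python
def plain_format(a: int, b: int) -> str:
    # Conventional format: the sum is written most-significant digit first.
    return f"{a}+{b}={a + b}"

def reversed_format(a: int, b: int) -> str:
    # Reversed-output format: the sum is written least-significant digit first,
    # matching the order in which carries are resolved during column addition.
    return f"{a}+{b}={str(a + b)[::-1]}"
```

For `567 + 789`, the plain string is `567+789=1356` and the reversed one is `567+789=6531`.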

LG · Jul 25, 2023
Settling the Sample Complexity of Online Reinforcement Learning

Zihan Zhang, Yuxin Chen, Jason D. Lee et al.

A central issue lying at the heart of online reinforcement learning (RL) is data efficiency. While a number of recent works achieved asymptotically minimal regret in online RL, the optimality of these results is only guaranteed in a "large-sample" regime, imposing an enormous burn-in cost in order for their algorithms to operate optimally. How to achieve minimax-optimal regret without incurring any burn-in cost has been an open problem in RL theory. We settle this problem for the context of finite-horizon inhomogeneous Markov decision processes. Specifically, we prove that a modified version of Monotonic Value Propagation (MVP), a model-based algorithm proposed by Zhang et al. (2020), achieves a regret on the order of (modulo log factors) $\min\{\sqrt{SAH^3K},\, HK\}$, where $S$ is the number of states, $A$ is the number of actions, $H$ is the planning horizon, and $K$ is the total number of episodes. This regret matches the minimax lower bound for the entire range of sample size $K\geq 1$, essentially eliminating any burn-in requirement. It also translates to a PAC sample complexity (i.e., the number of episodes needed to yield $\varepsilon$-accuracy) of $\frac{SAH^3}{\varepsilon^2}$ up to log factors, which is minimax-optimal for the full $\varepsilon$-range. Further, we extend our theory to unveil the influences of problem-dependent quantities like the optimal value/cost and certain variances. The key technical innovation lies in the development of a new regret decomposition strategy and a novel analysis paradigm to decouple complicated statistical dependency, a long-standing challenge facing the analysis of online RL in the sample-hungry regime.
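
A short calculation shows where each term of the regret bound is active and how it yields the stated PAC rate (log factors dropped throughout). The two terms cross at $K = SAH$:

$$\sqrt{SAH^3K} \le HK \iff SAH^3K \le H^2K^2 \iff K \ge SAH,$$

so the bound reads $HK$ for $K < SAH$ and $\sqrt{SAH^3K}$ for $K \ge SAH$. Assuming the usual regret-to-PAC conversion (average per-episode regret at most $\varepsilon$), the square-root regime gives

$$\frac{\sqrt{SAH^3K}}{K} \le \varepsilon \iff K \ge \frac{SAH^3}{\varepsilon^2},$$

which matches the sample complexity quoted above.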

LG · Sep 30, 2022
Self-Stabilization: The Implicit Bias of Gradient Descent at the Edge of Stability

Alex Damian, Eshaan Nichani, Jason D. Lee

Traditional analyses of gradient descent show that when the largest eigenvalue of the Hessian, also known as the sharpness $S(θ)$, is bounded by $2/η$, training is "stable" and the training loss decreases monotonically. Recent works, however, have observed that this assumption does not hold when training modern neural networks with full batch or large batch gradient descent. Most recently, Cohen et al. (2021) observed two important phenomena. The first, dubbed progressive sharpening, is that the sharpness steadily increases throughout training until it reaches the instability cutoff $2/η$. The second, dubbed edge of stability, is that the sharpness hovers at $2/η$ for the remainder of training while the loss continues decreasing, albeit non-monotonically. We demonstrate that, far from being chaotic, the dynamics of gradient descent at the edge of stability can be captured by a cubic Taylor expansion: as the iterates diverge in the direction of the top eigenvector of the Hessian due to instability, the cubic term in the local Taylor expansion of the loss function causes the curvature to decrease until stability is restored. This property, which we call self-stabilization, is a general property of gradient descent and explains its behavior at the edge of stability. A key consequence of self-stabilization is that gradient descent at the edge of stability implicitly follows projected gradient descent (PGD) under the constraint $S(θ) \le 2/η$. Our analysis provides precise predictions for the loss, sharpness, and deviation from the PGD trajectory throughout training, which we verify both empirically in a number of standard settings and theoretically under mild conditions. Our analysis uncovers the mechanism for gradient descent's implicit bias towards stability.
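
The classical $2/η$ threshold can be checked on a one-dimensional quadratic $L(x) = \frac{S}{2}x^2$, whose sharpness is the constant $S$: each GD step multiplies $x$ by $1 - ηS$, which shrinks $x$ exactly when $S < 2/η$. This sketch illustrates only that baseline stability condition; self-stabilization itself depends on the cubic term of the loss, which a pure quadratic lacks. The helper name is hypothetical.

```python
def gd_on_quadratic(sharpness, lr, x0=1.0, steps=50):
    """Run GD on L(x) = sharpness * x**2 / 2; each step multiplies x by (1 - lr * sharpness)."""
    x = x0
    for _ in range(steps):
        x -= lr * sharpness * x   # gradient of L is sharpness * x
    return x
```

With `lr = 0.1` the cutoff is `2 / lr = 20`: sharpness 19 gives a contraction factor of magnitude 0.9 and the iterate decays toward 0, while sharpness 21 gives magnitude 1.1 and the iterate oscillates with growing amplitude.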

LG · Jun 3, 2022
Decentralized Optimistic Hyperpolicy Mirror Descent: Provably No-Regret Learning in Markov Games

Wenhao Zhan, Jason D. Lee, Zhuoran Yang

We study decentralized policy learning in Markov games where we control a single agent to play with nonstationary and possibly adversarial opponents. Our goal is to develop a no-regret online learning algorithm that (i) takes actions based on the local information observed by the agent and (ii) is able to find the best policy in hindsight. For such a problem, the nonstationary state transitions due to the varying opponent pose a significant challenge. In light of a recent hardness result (Liu et al., 2022), we focus on the setting where the opponent's previous policies are revealed to the agent for decision making. With such an information structure, we propose a new algorithm, Decentralized Optimistic hypeRpolicy mIrror deScent (DORIS), which achieves $\sqrt{K}$-regret in the context of general function approximation, where $K$ is the number of episodes. Moreover, when all the agents adopt DORIS, we prove that their mixture policy constitutes an approximate coarse correlated equilibrium. In particular, DORIS maintains a hyperpolicy, which is a distribution over the policy space. The hyperpolicy is updated via mirror descent, where the update direction is obtained by an optimistic variant of least-squares policy evaluation. Furthermore, to illustrate the power of our method, we apply DORIS to constrained and vector-valued MDPs, which can be formulated as zero-sum Markov games with a fictitious opponent.
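
The hyperpolicy update can be pictured on a finite policy class: mirror descent with the negative-entropy mirror map is exactly a multiplicative-weights update on the distribution over policies. This is a minimal sketch under that assumption; the value estimates are supplied by hand here, whereas DORIS obtains its update direction from an optimistic least-squares policy evaluation that is not reproduced, and the helper name is hypothetical.

```python
import math

def hyperpolicy_update(weights, value_estimates, lr):
    """One entropic mirror-descent (multiplicative-weights) step on a distribution
    over a finite set of candidate policies."""
    scores = [w * math.exp(lr * v) for w, v in zip(weights, value_estimates)]
    total = sum(scores)
    return [s / total for s in scores]   # renormalize to a probability distribution
```

Starting from the uniform hyperpolicy over three policies and repeatedly feeding value estimates `[0.0, 1.0, 0.5]`, the mass concentrates on the second policy while the weights always remain a probability distribution.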

GT · Mar 3, 2023
Can We Find Nash Equilibria at a Linear Rate in Markov Games?

Zhuoqing Song, Jason D. Lee, Zhuoran Yang

We study decentralized learning in two-player zero-sum discounted Markov games where the goal is to design a policy optimization algorithm for either agent satisfying two properties. First, the player does not need to know the policy of the opponent to update its policy. Second, when both players adopt the algorithm, their joint policy converges to a Nash equilibrium of the game. To this end, we construct a meta algorithm, dubbed $\texttt{Homotopy-PO}$, which provably finds a Nash equilibrium at a global linear rate. In particular, $\texttt{Homotopy-PO}$ interweaves two base algorithms $\texttt{Local-Fast}$ and $\texttt{Global-Slow}$ via homotopy continuation. $\texttt{Local-Fast}$ is an algorithm that enjoys local linear convergence while $\texttt{Global-Slow}$ is an algorithm that converges globally but at a slower sublinear rate. By switching between these two base algorithms, $\texttt{Global-Slow}$ essentially serves as a "guide" that identifies a benign neighborhood where $\texttt{Local-Fast}$ enjoys fast convergence. However, since the exact size of such a neighborhood is unknown, we apply a doubling trick to switch between these two base algorithms. The switching scheme is delicately designed so that the aggregated performance of the algorithm is driven by $\texttt{Local-Fast}$. Furthermore, we prove that $\texttt{Local-Fast}$ and $\texttt{Global-Slow}$ can both be instantiated by variants of optimistic gradient descent/ascent (OGDA) method, which is of independent interest.

LG · Feb 9, 2023
Efficient displacement convex optimization with particle gradient descent

Hadi Daneshmand, Jason D. Lee, Chi Jin

Particle gradient descent, which uses particles to represent a probability measure and performs gradient descent on particles in parallel, is widely used to optimize functions of probability measures. This paper considers particle gradient descent with a finite number of particles and establishes its theoretical guarantees to optimize functions that are displacement convex in measures. Concretely, for Lipschitz displacement convex functions defined on probability measures over $\mathbb{R}^d$, we prove that $O(1/ε^2)$ particles and $O(d/ε^4)$ computations are sufficient to find an $ε$-optimal solution. We further provide improved complexity bounds for optimizing smooth displacement convex functions. We demonstrate the application of our results for function approximation with specific neural architectures with two-dimensional inputs.
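
A minimal one-dimensional sketch of particle gradient descent, assuming the standard displacement-convex example $F(\mu) = \mathbb{E}_{\mu}[V(x)] + \frac12\,\mathbb{E}_{x,y\sim\mu}[W(x-y)]$ with convex potential $V$ and convex, even interaction $W$ (this example functional and the helper name are illustrative choices, not the paper's setting). Representing $\mu$ by $n$ equally weighted particles, every particle moves along $-\big(\nabla V(x_i) + \frac1n\sum_k \nabla W(x_i - x_k)\big)$, all evaluated at the old positions.

```python
def particle_gd(particles, grad_V, grad_W, lr=0.1, steps=200):
    """Particle gradient descent on F(mu) = E[V(x)] + 1/2 E[W(x - y)] in one dimension.

    Each particle follows the Wasserstein-gradient direction of F at the empirical measure.
    """
    xs = list(particles)
    n = len(xs)
    for _ in range(steps):
        # Compute all directions from the current positions before moving anything.
        grads = [grad_V(x) + sum(grad_W(x - y) for y in xs) / n for x in xs]
        xs = [x - lr * g for x, g in zip(xs, grads)]
    return xs
```

With $V(x) = x^2/2$ and $W(z) = z^2/2$ the minimizer is the point mass at 0; calling `particle_gd([-2.0, 0.5, 3.0], grad_V=lambda x: x, grad_W=lambda z: z)` drives all three particles to within numerical precision of 0.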

LG · Dec 7, 2022
Reconstructing Training Data from Model Gradient, Provably

Zihan Wang, Jason D. Lee, Qi Lei

Understanding when and how much a model gradient leaks information about the training sample is an important question in privacy. In this paper, we present a surprising result: even without training or memorizing the data, we can fully reconstruct the training samples from a single gradient query at a randomly chosen parameter value. We prove the identifiability of the training data under mild conditions: with shallow or deep neural networks and a wide range of activation functions. We also present a statistically and computationally efficient algorithm based on tensor decomposition to reconstruct the training data. As a provable attack that reveals sensitive training data, our findings suggest potentially severe threats to privacy, especially in federated learning.

LG · Nov 20, 2023
Provably Efficient CVaR RL in Low-rank MDPs

Yulai Zhao, Wenhao Zhan, Xiaoyan Hu et al. · princeton

We study risk-sensitive Reinforcement Learning (RL), where we aim to maximize the Conditional Value at Risk (CVaR) with a fixed risk tolerance $τ$. Prior theoretical work studying risk-sensitive RL focuses on the tabular Markov Decision Processes (MDPs) setting. To extend CVaR RL to settings where state space is large, function approximation must be deployed. We study CVaR RL in low-rank MDPs with nonlinear function approximation. Low-rank MDPs assume the underlying transition kernel admits a low-rank decomposition, but unlike prior linear models, low-rank MDPs do not assume the feature or state-action representation is known. We propose a novel Upper Confidence Bound (UCB) bonus-driven algorithm to carefully balance the interplay between exploration, exploitation, and representation learning in CVaR RL. We prove that our algorithm achieves a sample complexity of $\tilde{O}\left(\frac{H^7 A^2 d^4}{τ^2 ε^2}\right)$ to yield an $ε$-optimal CVaR, where $H$ is the length of each episode, $A$ is the size of the action space, and $d$ is the dimension of representations. Computationally, we design a novel discretized Least-Squares Value Iteration (LSVI) algorithm for the CVaR objective as the planning oracle and show that we can find the near-optimal policy in a polynomial running time with a Maximum Likelihood Estimation oracle. To our knowledge, this is the first provably efficient CVaR RL algorithm in low-rank MDPs.
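
For readers new to the objective, CVaR at tolerance $τ$ of a return $Z$ is the expected return over the worst $τ$-fraction of outcomes, and it admits the variational form $\mathrm{CVaR}_τ(Z) = \max_b \{\, b - \frac{1}{τ}\mathbb{E}[(b - Z)^+] \,\}$. The sketch below evaluates this form on an empirical sample of returns; it assumes this standard lower-tail definition is the one used here and is not the paper's algorithm. For an empirical distribution the objective is concave and piecewise linear in $b$ with kinks at the sample points, so maximizing over the samples is exact.

```python
def empirical_cvar(returns, tau):
    """CVaR at tolerance tau (0 < tau <= 1) of an empirical return distribution,
    via CVaR_tau(Z) = max_b { b - E[(b - Z)^+] / tau }, maximized over sample points."""
    n = len(returns)
    return max(b - sum(max(b - z, 0.0) for z in returns) / (n * tau) for b in returns)
```

For returns `[1, 2, 3, 4]`, `tau = 0.5` averages the worst half and gives `1.5`, while `tau = 1.0` recovers the ordinary mean `2.5`.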

AI · Jul 18, 2024
Correcting the Mythos of KL-Regularization: Direct Alignment without Overoptimization via Chi-Squared Preference Optimization

Audrey Huang, Wenhao Zhan, Tengyang Xie et al.

Language model alignment methods such as reinforcement learning from human feedback (RLHF) have led to impressive advances in language model capabilities, but are limited by a widely observed phenomenon known as overoptimization, where the quality of the language model degrades over the course of the alignment process. As the model optimizes performance with respect to an offline reward model, it overfits to inaccuracies and drifts away from preferred responses covered by the data. To discourage such distribution shift, KL-regularization is widely employed in existing offline alignment methods, but overoptimization continues to harm performance. Lending theoretical insight into the source of these empirical observations, we first show that KL-regularization is too weak to prevent overfitting, then raise the following question: is it possible to design an efficient algorithm that is provably robust to overoptimization? We address this question with a new algorithm for offline alignment, $χ^2$-Preference Optimization ($χ$PO). $χ$PO is a one-line change to Direct Preference Optimization (DPO; Rafailov et al., 2023), which only involves modifying the logarithmic link function in the DPO objective. Despite this minimal change, $χ$PO implicitly implements the principle of pessimism in the face of uncertainty via regularization with the $χ^2$-divergence (which quantifies uncertainty more effectively than KL-regularization) and provably alleviates overoptimization, achieving sample-complexity guarantees based on single-policy concentrability, the gold standard in offline reinforcement learning. $χ$PO's simplicity and strong guarantees make it the first practical and general-purpose offline alignment algorithm that is provably robust to overoptimization.
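
The "one-line change" can be seen by writing the DPO-style loss with a pluggable link function. As we read the $χ$PO paper, the logarithmic link $\log z$ applied to the ratio $z = \pi(y \mid x)/\pi_{\text{ref}}(y \mid x)$ is replaced by $z + \log z$; treat this exact form as an assumption to check against the paper, which (to our understanding) also clips the reward difference, a step omitted here. The function names are hypothetical, and each argument is a scalar log-ratio for one response.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def preference_loss(logratio_chosen, logratio_rejected, beta, link):
    """-log sigmoid(beta * (link(chosen) - link(rejected))), where each log-ratio is
    log pi(y|x) - log pi_ref(y|x) for the chosen or rejected response."""
    margin = beta * (link(logratio_chosen) - link(logratio_rejected))
    return -log_sigmoid(margin)

def dpo_link(logratio):
    return logratio                          # log(pi / pi_ref)

def chi_po_link(logratio):
    return math.exp(logratio) + logratio     # pi / pi_ref + log(pi / pi_ref), assumed form
```

The two objectives differ only in the `link` argument. The extra linear term $z$ grows much faster than $\log z$ when the policy places far more mass than the reference on a response, which is the intuition behind the stronger $χ^2$-type regularization.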

LG · May 18, 2022
On the Effective Number of Linear Regions in Shallow Univariate ReLU Networks: Convergence Guarantees and Implicit Bias

Itay Safran, Gal Vardi, Jason D. Lee

We study the dynamics and implicit bias of gradient flow (GF) on univariate ReLU neural networks with a single hidden layer in a binary classification setting. We show that when the labels are determined by the sign of a target network with $r$ neurons, with high probability over the initialization of the network and the sampling of the dataset, GF converges in direction (suitably defined) to a network achieving perfect training accuracy and having at most $\mathcal{O}(r)$ linear regions, implying a generalization bound. Unlike many other results in the literature, under an additional assumption on the distribution of the data, our result holds even for mild over-parameterization, where the width is $\tilde{\mathcal{O}}(r)$ and independent of the sample size.

LG · Oct 13, 2022
From Gradient Flow on Population Loss to Learning with Stochastic Gradient Descent

Satyen Kale, Jason D. Lee, Chris De Sa et al.

Stochastic Gradient Descent (SGD) has been the method of choice for learning large-scale non-convex models. While a general analysis of when SGD works has been elusive, there has been a lot of recent progress in understanding the convergence of Gradient Flow (GF) on the population loss, partly due to the simplicity that a continuous-time analysis buys us. An overarching theme of our paper is providing general conditions under which SGD converges, assuming that GF on the population loss converges. Our main tool to establish this connection is a general converse Lyapunov-like theorem, which implies the existence of a Lyapunov potential under mild assumptions on the rates of convergence of GF. In fact, using these potentials, we show a one-to-one correspondence between rates of convergence of GF and geometrical properties of the underlying objective. When these potentials further satisfy certain self-bounding properties, we show that they can be used to provide a convergence guarantee for Gradient Descent (GD) and SGD (even when the paths of GF and GD/SGD are quite far apart). It turns out that these self-bounding assumptions are in a sense also necessary for GD/SGD to work. Using our framework, we provide a unified analysis for GD/SGD not only for classical settings like convex losses, or objectives that satisfy PL / KL properties, but also for more complex problems including Phase Retrieval and Matrix Square Root, and extending the results in the recent work of Chatterjee 2022.

LGJun 8, 2022
Identifying good directions to escape the NTK regime and efficiently learn low-degree plus sparse polynomials

Eshaan Nichani, Yu Bai, Jason D. Lee

A recent goal in the theory of deep learning is to identify how neural networks can escape the "lazy training," or Neural Tangent Kernel (NTK) regime, where the network is coupled with its first-order Taylor expansion at initialization. While the NTK is minimax optimal for learning dense polynomials (Ghorbani et al., 2021), it cannot learn features, and hence has poor sample complexity for learning many classes of functions including sparse polynomials. Recent works have thus aimed to identify settings where gradient-based algorithms provably generalize better than the NTK. One such example is the "QuadNTK" approach of Bai and Lee (2020), which analyzes the second-order term in the Taylor expansion. Bai and Lee (2020) show that the second-order term can learn sparse polynomials efficiently; however, it sacrifices the ability to learn general dense polynomials. In this paper, we analyze how gradient descent on a two-layer neural network can escape the NTK regime by utilizing a spectral characterization of the NTK (Montanari and Zhong, 2020) and building on the QuadNTK approach. We first expand upon the spectral analysis to identify "good" directions in parameter space in which we can move without harming generalization. Next, we show that a wide two-layer neural network can jointly use the NTK and QuadNTK to fit target functions consisting of a dense low-degree term and a sparse high-degree term, something neither the NTK nor the QuadNTK can do on its own. Finally, we construct a regularizer which encourages our parameter vector to move in the "good" directions, and show that gradient descent on the regularized loss will converge to a global minimizer, which also has low test error. This yields an end-to-end convergence and generalization guarantee with provable sample complexity improvement over both the NTK and QuadNTK on their own.

CLJul 5, 2023
Scaling In-Context Demonstrations with Structured Attention

Tianle Cai, Kaixuan Huang, Jason D. Lee et al.

The recent surge of large language models (LLMs) highlights their ability to perform in-context learning, i.e., "learning" to perform a task from a few demonstrations in the context without any parameter updates. However, their in-context learning capabilities are limited by the model architecture: 1) the use of demonstrations is constrained by a maximum sentence length due to positional embeddings; 2) the quadratic complexity of attention prevents users from using more demonstrations efficiently; 3) LLMs are shown to be sensitive to the order of the demonstrations. In this work, we tackle these challenges by proposing a better architectural design for in-context learning. We propose SAICL (Structured Attention for In-Context Learning), which replaces full attention with a structured attention mechanism designed for in-context learning and removes unnecessary dependencies between individual demonstrations, while making the model invariant to the permutation of demonstrations. We evaluate SAICL in a meta-training framework and show that SAICL achieves comparable or better performance than full attention while obtaining up to 3.4x inference speed-up. SAICL also consistently outperforms a strong Fusion-in-Decoder (FiD) baseline which processes each demonstration independently. Finally, thanks to its linear nature, we demonstrate that SAICL can easily scale to hundreds of demonstrations, with performance continuing to improve as more demonstrations are added.
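
A minimal sketch of one possible attention mask in the spirit of SAICL. It assumes (hypothetically) that the prompt is laid out as consecutive demonstration segments followed by the test input, that each demonstration token attends only causally within its own demonstration, and that test tokens attend causally to everything before them. This illustrates removing cross-demonstration dependencies; it is not the exact SAICL mechanism, and the function name is an assumption.

```python
def structured_icl_mask(demo_lengths, test_length):
    """Return mask where mask[i][j] is True when query token i may attend to key token j."""
    seg = []  # segment id for every position; -1 marks the test segment
    for d, n in enumerate(demo_lengths):
        seg += [d] * n
    seg += [-1] * test_length
    total = len(seg)
    mask = [[False] * total for _ in range(total)]
    for i in range(total):
        for j in range(i + 1):  # causal: only keys at or before the query
            # Test tokens see every earlier token; demo tokens see only their own demo.
            if seg[i] == -1 or seg[i] == seg[j]:
                mask[i][j] = True
    return mask
```

Because no demonstration token attends to another demonstration, reordering the demonstrations changes only where each segment sits, not what any demonstration token sees.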

LGFeb 22, 2023
Provably Efficient Reinforcement Learning via Surprise Bound

Hanlin Zhu, Ruosong Wang, Jason D. Lee

Value function approximation is important in modern reinforcement learning (RL) problems, especially when the state space is (infinitely) large. Despite the importance and wide applicability of value function approximation, its theoretical understanding still lags behind its empirical success, especially in the context of general function approximation. In this paper, we propose a provably efficient RL algorithm (both computationally and statistically) with general value function approximations. We show that if the value functions can be approximated by a function class that satisfies the Bellman-completeness assumption, our algorithm achieves an $\widetilde{O}(\mathrm{poly}(\iota H)\sqrt{T})$ regret bound, where $\iota$ is the product of the surprise bound and log-covering numbers, $H$ is the planning horizon, $K$ is the number of episodes, and $T = HK$ is the total number of steps the agent interacts with the environment. Our algorithm achieves reasonable regret bounds when applied to both the linear setting and the sparse high-dimensional linear setting. Moreover, our algorithm only needs to solve $O(H\log K)$ empirical risk minimization (ERM) problems, which is far more efficient than previous algorithms that need to solve $\Omega(HK)$ ERM problems.

LGJun 21, 2023
Sample Complexity for Quadratic Bandits: Hessian Dependent Bounds and Optimal Algorithms

Qian Yu, Yining Wang, Baihe Huang et al.

In stochastic zeroth-order optimization, a problem of practical relevance is understanding how to fully exploit the local geometry of the underlying objective function. We consider a fundamental setting in which the objective function is quadratic, and provide the first tight characterization of the optimal Hessian-dependent sample complexity. Our contribution is twofold. First, from an information-theoretic point of view, we prove tight lower bounds on Hessian-dependent complexities by introducing a concept called energy allocation, which captures the interaction between the searching algorithm and the geometry of objective functions. A matching upper bound is obtained by solving for the optimal energy spectrum. Then, algorithmically, we show the existence of a Hessian-independent algorithm that universally achieves the asymptotically optimal sample complexities for all Hessian instances. The optimal sample complexities achieved by our algorithm remain valid for heavy-tailed noise distributions, which is enabled by a truncation method.

LGMar 27
Sharp Capacity Scaling of Spectral Optimizers in Learning Associative Memory

Juno Kim, Eshaan Nichani, Denny Wu et al.

Spectral optimizers such as Muon have recently shown strong empirical performance in large-scale language model training, but the source and extent of their advantage remain poorly understood. We study this question through the linear associative memory problem, a tractable model for factual recall in transformer-based models. In particular, we go beyond orthogonal embeddings and consider Gaussian inputs and outputs, which allows the number of stored associations to greatly exceed the embedding dimension. Our main result sharply characterizes the recovery rates of one step of Muon and SGD on the logistic regression loss under a power-law frequency distribution. We show that the storage capacity of Muon significantly exceeds that of SGD, and moreover Muon saturates at a larger critical batch size. We further analyze the multi-step dynamics under a thresholded gradient approximation and show that Muon achieves a substantially faster initial recovery rate than SGD, while both methods eventually converge to the information-theoretic limit at comparable speeds. Experiments on synthetic tasks validate the predicted scaling laws. Our analysis provides a quantitative understanding of the signal amplification of Muon and lays the groundwork for establishing scaling laws across more practical language modeling tasks and optimizers.
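
Muon's update replaces the raw gradient $G = U S V^\top$ with (an approximation of) its orthogonalized form $U V^\top$, which gives every singular direction the same step size; SGD uses $G$ itself. The pure-Python sketch below contrasts the two single-step updates on a small weight matrix. As an assumption, it uses the classical cubic Newton-Schulz iteration $X \leftarrow \tfrac{3}{2}X - \tfrac{1}{2}XX^\top X$ rather than the tuned higher-order iteration used in practical Muon implementations, and it omits momentum. All function names are hypothetical.

```python
def matmul(A, B):
    """Product of two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(r) for r in zip(*A)]


def orthogonalize(G, steps=30):
    """Approximate the polar factor U V^T of G = U S V^T with cubic Newton-Schulz."""
    norm = sum(x * x for row in G for x in row) ** 0.5
    X = [[x / norm for x in row] for row in G]  # Frobenius scaling puts singular values in (0, 1]
    for _ in range(steps):
        XXtX = matmul(matmul(X, transpose(X)), X)
        # Each singular value s is mapped to 1.5 s - 0.5 s^3, which drives it toward 1.
        X = [[1.5 * x - 0.5 * y for x, y in zip(rx, ry)] for rx, ry in zip(X, XXtX)]
    return X


def sgd_step(W, G, lr):
    return [[w - lr * g for w, g in zip(rw, rg)] for rw, rg in zip(W, G)]


def spectral_step(W, G, lr):
    O = orthogonalize(G)
    return [[w - lr * o for w, o in zip(rw, ro)] for rw, ro in zip(W, O)]
```

For example, with $G = \mathrm{diag}(3, 0.1)$ the SGD step moves the weak direction 30 times less than the strong one, whereas `orthogonalize(G)` returns approximately the identity, so the spectral step moves both directions by the same amount. This is one way to picture the signal amplification discussed above.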

LGFeb 18, 2024Code
Revisiting Zeroth-Order Optimization for Memory-Efficient LLM Fine-Tuning: A Benchmark

Yihua Zhang, Pingzhi Li, Junyuan Hong et al.

In the evolving landscape of natural language processing (NLP), fine-tuning pre-trained Large Language Models (LLMs) with first-order (FO) optimizers like SGD and Adam has become standard. Yet, as LLMs grow in size, the substantial memory overhead from back-propagation (BP) for FO gradient computation presents a significant challenge. Addressing this issue is crucial, especially for applications like on-device training where memory efficiency is paramount. This paper proposes a shift towards BP-free, zeroth-order (ZO) optimization as a solution for reducing memory costs during LLM fine-tuning, building on the initial concept introduced by MeZO. Unlike traditional ZO-SGD methods, our work expands the exploration to a wider array of ZO optimization techniques, through a comprehensive, first-of-its-kind benchmarking study across five LLM families (RoBERTa, OPT, LLaMA, Vicuna, Mistral), three task complexities, and five fine-tuning schemes. Our study unveils previously overlooked optimization principles, highlighting the importance of task alignment, the role of the forward gradient method, and the balance between algorithm complexity and fine-tuning performance. We further introduce novel enhancements to ZO optimization, including block-wise descent, hybrid training, and gradient sparsity. Our study offers a promising direction for achieving further memory-efficient LLM fine-tuning. Code to reproduce all our experiments is available at https://github.com/ZO-Bench/ZO-LLM.
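
The MeZO-style estimator these ZO methods build on can be sketched as a two-point (SPSA-style) finite difference along a random Gaussian direction $z$: $\hat g = \frac{L(\theta + \mu z) - L(\theta - \mu z)}{2\mu}\, z$, which needs only two forward evaluations and no back-propagation. The pure-Python version below is a toy sketch on a small parameter list; the step size, smoothing radius $\mu$, and function names are illustrative assumptions, and it covers only basic ZO-SGD, not the other techniques in the benchmark.

```python
import random


def zo_gradient(loss, theta, mu=1e-3, rng=random):
    """Two-point gradient estimate from two forward evaluations along a random direction z."""
    z = [rng.gauss(0.0, 1.0) for _ in theta]
    plus = loss([t + mu * zi for t, zi in zip(theta, z)])
    minus = loss([t - mu * zi for t, zi in zip(theta, z)])
    scale = (plus - minus) / (2 * mu)  # finite-difference directional derivative along z
    return [scale * zi for zi in z]


def zo_sgd(loss, theta, lr=0.05, steps=2000, seed=0):
    """Plain ZO-SGD: replace the back-propagated gradient with the two-point estimate."""
    rng = random.Random(seed)
    for _ in range(steps):
        g = zo_gradient(loss, theta, rng=rng)
        theta = [t - lr * gi for t, gi in zip(theta, g)]
    return theta
```

For instance, `zo_sgd(lambda t: sum((x - 1.0) ** 2 for x in t), [0.0, 0.0, 0.0])` drives all three coordinates close to 1. At LLM scale, MeZO avoids storing $z$ by regenerating it from a saved random seed; this sketch simply keeps $z$ in memory.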

LGNov 23, 2023
Learning Hierarchical Polynomials with Three-Layer Neural Networks

Zihao Wang, Eshaan Nichani, Jason D. Lee

We study the problem of learning hierarchical polynomials over the standard Gaussian distribution with three-layer neural networks. We specifically consider target functions of the form $h = g \circ p$ where $p : \mathbb{R}^d \rightarrow \mathbb{R}$ is a degree $k$ polynomial and $g: \mathbb{R} \rightarrow \mathbb{R}$ is a degree $q$ polynomial. This function class generalizes the single-index model, which corresponds to $k=1$, and is a natural class of functions possessing an underlying hierarchical structure. Our main result shows that for a large subclass of degree $k$ polynomials $p$, a three-layer neural network trained via layerwise gradient descent on the square loss learns the target $h$ up to vanishing test error in $\widetilde{\mathcal{O}}(d^k)$ samples and polynomial time. This is a strict improvement over kernel methods, which require $\widetilde{\Theta}(d^{kq})$ samples, as well as existing guarantees for two-layer networks, which require the target function to be low-rank. Our result also generalizes prior works on three-layer neural networks, which were restricted to the case of $p$ being a quadratic. When $p$ is indeed a quadratic, we achieve the information-theoretically optimal sample complexity $\widetilde{\mathcal{O}}(d^2)$, which is an improvement over prior work (Nichani et al., 2023) requiring a sample size of $\widetilde{\Theta}(d^4)$. Our proof proceeds by showing that during the initial stage of training the network performs feature learning to recover the feature $p$ with $\widetilde{\mathcal{O}}(d^k)$ samples. This work demonstrates the ability of three-layer neural networks to learn complex features and, as a result, learn a broad class of hierarchical functions.

LGApr 12, 2024Code
Dataset Reset Policy Optimization for RLHF

Jonathan D. Chang, Wenhao Zhan, Owen Oertell et al.

Reinforcement Learning (RL) from Human Preference-based feedback is a popular paradigm for fine-tuning generative models, which has produced impressive models such as GPT-4 and Claude 3 Opus. This framework often consists of two steps: learning a reward model from an offline preference dataset followed by running online RL to optimize the learned reward model. In this work, leveraging the idea of reset, we propose a new RLHF algorithm with provable guarantees. Motivated by the fact that an offline preference dataset provides informative states (i.e., data that is preferred by the labelers), our new algorithm, Dataset Reset Policy Optimization (DR-PO), integrates the existing offline preference dataset into the online policy training procedure via dataset reset: it directly resets the policy optimizer to the states in the offline dataset, instead of always starting from the initial state distribution. In theory, we show that DR-PO learns to perform at least as well as any policy that is covered by the offline dataset under general function approximation with finite sample complexity. In experiments, we demonstrate that on both the TL;DR summarization and the Anthropic Helpful Harmful (HH) dataset, the generation from DR-PO is better than that from Proximal Policy Optimization (PPO) and Direct Preference Optimization (DPO), under the metric of GPT-4 win rate. Code for this work can be found at https://github.com/Cornell-RL/drpo.

LGMay 27, 2025Code
Accelerating RL for LLM Reasoning with Optimal Advantage Regression

Kianté Brantley, Mingyu Chen, Zhaolin Gao et al.

Reinforcement learning (RL) has emerged as a powerful tool for fine-tuning large language models (LLMs) to improve complex reasoning abilities. However, state-of-the-art policy optimization methods often suffer from high computational overhead and memory consumption, primarily due to the need for multiple generations per prompt and the reliance on critic networks or advantage estimates of the current policy. In this paper, we propose $A^*$-PO, a novel two-stage policy optimization framework that directly approximates the optimal advantage function and enables efficient training of LLMs for reasoning tasks. In the first stage, we leverage offline sampling from a reference policy to estimate the optimal value function $V^*$, eliminating the need for costly online value estimation. In the second stage, we perform on-policy updates using a simple least-squares regression loss with only a single generation per prompt. Theoretically, we establish performance guarantees and prove that the KL-regularized RL objective can be optimized without requiring complex exploration strategies. Empirically, $A^*$-PO achieves competitive performance across a wide range of mathematical reasoning benchmarks, while reducing training time by up to 2$\times$ and peak memory usage by over 30% compared to PPO, GRPO, and REBEL. An implementation of $A^*$-PO is available at https://github.com/ZhaolinGao/A-PO.
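
Under KL regularization with coefficient $\beta$ toward the reference policy, the optimal value function has the closed form $V^*(x) = \beta \log \mathbb{E}_{y \sim \pi_{\mathrm{ref}}}\big[\exp(r(x,y)/\beta)\big]$, and the optimal policy satisfies $\beta \log \frac{\pi^*(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} = r(x,y) - V^*(x)$. A minimal sketch of the two stages described above, assuming this standard KL-regularized setup: stage one estimates $V^*$ from the rewards of several reference samples with a numerically stable log-mean-exp, and stage two regresses $\beta \log(\pi/\pi_{\mathrm{ref}})$ onto the estimated advantage for a single on-policy sample. Function names are hypothetical, and the exact loss used by $A^*$-PO may differ in its details.

```python
import math


def estimate_v_star(rewards, beta):
    """Stage 1: V*(x) ~= beta * log(mean_i exp(r_i / beta)) from reference-policy samples."""
    m = max(r / beta for r in rewards)  # shift by the max for numerical stability
    lse = m + math.log(sum(math.exp(r / beta - m) for r in rewards))
    return beta * (lse - math.log(len(rewards)))


def regression_loss(logp_policy, logp_ref, reward, v_star, beta):
    """Stage 2: squared error between beta * log(pi / pi_ref) and the estimated optimal advantage."""
    pred = beta * (logp_policy - logp_ref)
    return (pred - (reward - v_star)) ** 2
```

As $\beta$ shrinks, the estimate approaches the best sampled reward; as $\beta$ grows, it approaches the mean reward of the reference samples.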

LGFeb 2Code
Statistical Learning Theory in Lean 4: Empirical Processes from Scratch

Yuanhe Zhang, Jason D. Lee, Fanghui Liu

We present the first comprehensive Lean 4 formalization of statistical learning theory (SLT) grounded in empirical process theory. Our end-to-end formal infrastructure implements content missing from the latest Lean 4 Mathlib library, including a complete development of Gaussian Lipschitz concentration, the first formalization of Dudley's entropy integral theorem for sub-Gaussian processes, and an application to least-squares (sparse) regression with a sharp rate. The project was carried out using a human-AI collaborative workflow, in which humans design proof strategies and AI agents execute tactical proof construction, resulting in a human-verified Lean 4 toolbox for SLT. Beyond implementation, the formalization process exposes and resolves implicit assumptions and missing details in standard SLT textbooks, enforcing a granular, line-by-line understanding of the theory. This work establishes a reusable formal foundation and opens the door for future developments in machine learning theory. The code is available at https://github.com/YuanheZ/lean-stat-learning-theory

LGOct 30, 2025
Quantitative Bounds for Length Generalization in Transformers

Zachary Izzo, Eshaan Nichani, Jason D. Lee

We study the problem of length generalization (LG) in transformers: the ability of a model trained on shorter sequences to maintain performance when evaluated on much longer, previously unseen inputs. Prior work by Huang et al. (2025) established that transformers eventually achieve length generalization once the training sequence length exceeds some finite threshold, but left open the question of how large this threshold must be. In this work, we provide the first quantitative bounds on the required training length for length generalization to occur. Motivated by previous empirical and theoretical work, we analyze LG in several distinct problem settings: $\ell_\infty$ error control vs. average error control over an input distribution, infinite-precision softmax attention vs. finite-precision attention (which reduces to an argmax) in the transformer, and one- vs. two-layer transformers. In all scenarios, we prove that LG occurs when the internal behavior of the transformer on longer sequences can be "simulated" by its behavior on shorter sequences seen during training. Our bounds give qualitative estimates for the length of training data required for a transformer to generalize, and we verify these insights empirically. These results sharpen our theoretical understanding of the mechanisms underlying extrapolation in transformers, and formalize the intuition that richer training data is required for generalization on more complex tasks.

CLJan 1, 2025Code
Rethinking Addressing in Language Models via Contexualized Equivariant Positional Encoding

Jiajun Zhu, Peihao Wang, Ruisi Cai et al.

Transformers rely on both content-based and position-based addressing mechanisms to make predictions, but existing positional encoding techniques often diminish the effectiveness of position-based addressing. Many current methods enforce rigid patterns in attention maps, limiting the ability to model long-range dependencies and adapt to diverse tasks. Additionally, most positional encodings are learned as general biases, lacking the specialization required for different instances within a dataset. To address this, we propose conTextualized equivariAnt Position Encoding (TAPE), a novel framework that enhances positional embeddings by incorporating sequence content across layers. TAPE introduces dynamic, context-aware positional encodings, overcoming the constraints of traditional fixed patterns. We show that TAPE can provably facilitate LLM reasoning ability by emulating a broader class of algorithms. By enforcing permutation and orthogonal equivariance, TAPE ensures the stability of positional encodings during updates, improving long-context ability. Our method can be easily integrated into pre-trained transformers, offering parameter-efficient fine-tuning with minimal overhead. Extensive experiments show that TAPE achieves superior performance in language modeling, arithmetic reasoning, and long-context retrieval tasks compared to existing positional embedding techniques. Code is available at https://github.com/VITA-Group/TAPE.

LGJan 27
Provable Learning of Random Hierarchy Models and Hierarchical Shallow-to-Deep Chaining

Yunwei Ren, Yatin Dandi, Florent Krzakala et al.

The empirical success of deep learning is often attributed to deep networks' ability to exploit hierarchical structure in data, constructing increasingly complex features across layers. Yet despite substantial progress in deep learning theory, most optimization results still focus on networks with only two or three layers, leaving the theoretical understanding of hierarchical learning in genuinely deep models limited. This leads to a natural question: can we prove that deep networks, trained by gradient-based methods, can efficiently exploit hierarchical structure? In this work, we consider Random Hierarchy Models, a hierarchical context-free grammar introduced in arXiv:2307.02129 and conjectured to separate deep and shallow networks. We prove that, under mild conditions, a deep convolutional network can be efficiently trained to learn this function class. Our proof builds on a general observation: if intermediate layers can receive clean signal from the labels and the relevant features are weakly identifiable, then layerwise training, i.e., training each layer individually, suffices to hierarchically learn the target function.

AIOct 19, 2025Code
DAG-Math: Graph-Guided Mathematical Reasoning in LLMs

Yuanhe Zhang, Ilja Kuzborskij, Jason D. Lee et al.

Large Language Models (LLMs) demonstrate strong performance on mathematical problems when prompted with Chain-of-Thought (CoT), yet it remains unclear whether this success stems from search, rote procedures, or rule-consistent reasoning. To address this, we propose modeling CoT as a certain rule-based stochastic process over directed acyclic graphs (DAGs), where nodes represent intermediate derivation states and edges encode rule applications. Within this framework, we introduce logical closeness, a metric that quantifies how well a model's CoT trajectory (i.e., the LLM's final output) adheres to the DAG structure, providing evaluation beyond classical PASS@k metrics. Building on this, we introduce the DAG-MATH CoT format and construct a benchmark that guides LLMs to generate CoT trajectories in this format, thereby enabling the evaluation of their reasoning ability under our framework. Across standard mathematical reasoning datasets, our analysis uncovers statistically significant differences in reasoning fidelity among representative LLM families, even when PASS@k is comparable, highlighting gaps between final-answer accuracy and rule-consistent derivation. Our framework provides a balance between free-form CoT and formal proof systems, offering actionable diagnostics for LLM reasoning evaluation. Our benchmark and code are available at: https://github.com/YuanheZ/DAG-MATH-Formatted-CoT.

LGSep 22, 2020Code
Sanity-Checking Pruning Methods: Random Tickets can Win the Jackpot

Jingtong Su, Yihang Chen, Tianle Cai et al.

Network pruning is a method for reducing test-time computational resource requirements with minimal performance degradation. Conventional wisdom about pruning algorithms suggests that: (1) Pruning methods exploit information from training data to find good subnetworks; (2) The architecture of the pruned network is crucial for good performance. In this paper, we conduct sanity checks for the above beliefs on several recent unstructured pruning methods and surprisingly find that: (1) A set of methods which aims to find good subnetworks of the randomly-initialized network (which we call "initial tickets") hardly exploits any information from the training data; (2) For the pruned networks obtained by these methods, randomly changing the preserved weights in each layer, while keeping the total number of preserved weights unchanged per layer, does not affect the final performance. These findings inspire us to choose a series of simple data-independent prune ratios for each layer, and randomly prune each layer accordingly to get a subnetwork (which we call "random tickets"). Experimental results show that our zero-shot random tickets outperform or attain a similar performance compared to existing "initial tickets". In addition, we identify one existing pruning method that passes our sanity checks. We hybridize the ratios in our random ticket with this method and propose a new method called "hybrid tickets", which achieves further improvement. (Our code is publicly available at https://github.com/JingtongSu/sanity-checking-pruning)
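
Random tickets, as described above, need only a per-layer keep ratio and a random number generator; no training data is consulted. A minimal sketch, assuming the ratios are supplied by the caller (the specific data-independent ratio schedule from the paper is not reproduced here) and treating each layer as a flat list of weights. The function name is hypothetical.

```python
import random


def random_ticket_masks(layer_sizes, keep_ratios, seed=0):
    """Randomly keep a fixed fraction of weights in each layer, without looking at any data."""
    rng = random.Random(seed)
    masks = []
    for n, ratio in zip(layer_sizes, keep_ratios):
        k = round(ratio * n)  # number of weights preserved in this layer
        kept = set(rng.sample(range(n), k))  # which weights survive is purely random
        masks.append([1 if i in kept else 0 for i in range(n)])
    return masks
```

Only the per-layer counts are fixed by the ratios; which individual weights survive is random, mirroring the sanity check that reshuffling preserved weights within a layer does not hurt performance.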

LGFeb 22, 2024
How Transformers Learn Causal Structure with Gradient Descent

Eshaan Nichani, Alex Damian, Jason D. Lee

The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure, which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.

MLMay 6
Sharp Capacity Thresholds in Linear Associative Memory: From Winner-Take-All to Listwise Retrieval

Nicholas Barnfield, Juno Kim, Eshaan Nichani et al.

How many key-value associations can a $d\times d$ linear memory store? We show that the answer depends not only on the $d^2$ degrees of freedom in the memory matrix, but also on the retrieval criterion. In an isotropic Gaussian model for the stored pairs, we show that top-1 retrieval, where every signal must beat its largest distractor, requires the logarithmic model-size scale $d^2\asymp n\log n$. We prove that the correlation matrix memory construction, which stores associations by superposing key-target outer products, achieves this scale through a sharp phase transition, and that the same scaling is necessary for any linear memory. Thus the logarithm is the intrinsic extreme-value price of winner-take-all decoding. We next consider listwise retrieval, where the correct target need not be the unique top-scoring item but should remain among the strongest candidates. To formalize this regime, we propose the Tail-Average Margin (TAM), a convex upper-tail criterion that certifies inclusion of the correct target in a controlled candidate list. Under this listwise retrieval criterion, the capacity follows the quadratic scale $d^2\asymp n$. At load $n/d^2\to\alpha$, we develop an exact asymptotic theory for the TAM empirical-risk minimizer through a two-parameter scalar variational principle. The theory has a rich phenomenology: in the ridgeless limit it yields a closed-form critical load separating satisfiable and unsatisfiable phases, and it predicts the limiting laws of true scores, competitor scores, margins, and percentile profiles. Finally, a small-tail extrapolation further leads to the conjectural sharp top-1 threshold $d^2\sim 2n\log n$.
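
The correlation matrix memory stores $n$ pairs by superposition, $W = \sum_i y_i x_i^\top$, and top-1 retrieval succeeds for key $x_i$ when $y_i^\top W x_i$ beats every distractor score $y_j^\top W x_i$ with $j \neq i$. The sketch below simulates this with isotropic Gaussian keys and targets in pure Python. It only illustrates the construction and the winner-take-all criterion; the function name and the sizes in the usage note are arbitrary choices, not the paper's thresholds.

```python
import random


def correlation_memory_top1_accuracy(d, n, seed=0):
    """Store n Gaussian (key, target) pairs in W = sum_i y_i x_i^T; return the top-1 retrieval rate."""
    rng = random.Random(seed)
    keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    targets = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    # W[a][b] = sum_i targets[i][a] * keys[i][b]: superposed key-target outer products.
    W = [[sum(y[a] * x[b] for x, y in zip(keys, targets)) for b in range(d)] for a in range(d)]
    correct = 0
    for i, x in enumerate(keys):
        readout = [sum(W[a][b] * x[b] for b in range(d)) for a in range(d)]  # W x_i
        scores = [sum(ya * ra for ya, ra in zip(y, readout)) for y in targets]  # y_j^T W x_i
        if max(range(n), key=scores.__getitem__) == i:  # winner-take-all decoding
            correct += 1
    return correct / n
```

With $d = 64$ and $n = 4$ (a load far below $d^2$), retrieval is essentially always correct, while with $d = 4$ and $n = 200$ the load greatly exceeds $d^2 = 16$ and most retrievals fail.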

LGMar 6
Improved high-dimensional estimation with Langevin dynamics and stochastic weight averaging

Stanley Wei, Alex Damian, Jason D. Lee

Significant recent work has studied the ability of gradient descent to recover a hidden planted direction $\theta^\star \in S^{d-1}$ in different high-dimensional settings, including tensor PCA and single-index models. The key quantity that governs the ability of gradient descent to traverse these landscapes is the information exponent $k^\star$ (Ben Arous et al., 2021), which corresponds to the order of the saddle at initialization in the population landscape. Ben Arous et al. (2021) showed that $n \gtrsim d^{\max(1, k^\star-1)}$ samples were necessary and sufficient for online SGD to recover $\theta^\star$, and Ben Arous et al. (2020) proved a similar lower bound for Langevin dynamics. More recently, Damian et al. (2023) showed it was possible to circumvent these lower bounds by running gradient descent on a smoothed landscape, and that this algorithm succeeds with $n \gtrsim d^{\max(1, k^\star/2)}$ samples, which is optimal in the worst case. This raises the question of whether it is possible to achieve the same rate without explicit smoothing. In this paper, we show that Langevin dynamics can succeed with $n \gtrsim d^{k^\star/2}$ samples if one considers the average iterate, rather than the last iterate. The key idea is that the combination of noise injection and iterate averaging is able to emulate the effect of landscape smoothing. We apply this result to both the tensor PCA and single-index model settings. Finally, we conjecture that minibatch SGD can also achieve the same rate without adding any additional noise.
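
A toy sketch of the two ingredients named above, noise injection and iterate averaging, on the unit sphere: each step takes a gradient step plus Gaussian noise of scale $\sqrt{2\eta T}$ and renormalizes, and the estimate is the normalized average of all iterates rather than the last one. The objective, temperature $T$, step size $\eta$, projection by renormalization, and function names are assumptions for illustration; this does not reproduce the paper's algorithm or its sample-complexity analysis.

```python
import math
import random


def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def spherical_langevin_average(grad, theta0, lr, temperature, steps, seed=0):
    """Noisy projected gradient descent on the unit sphere.

    Returns (normalized average iterate, last iterate).
    """
    rng = random.Random(seed)
    noise_std = math.sqrt(2 * lr * temperature)  # Langevin noise scale sqrt(2 * eta * T)
    theta = normalize(theta0)
    running_sum = [0.0] * len(theta)
    for _ in range(steps):
        g = grad(theta)
        step = [t - lr * gi + noise_std * rng.gauss(0, 1) for t, gi in zip(theta, g)]
        theta = normalize(step)  # project back onto the sphere
        running_sum = [s + t for s, t in zip(running_sum, theta)]
    return normalize(running_sum), theta
```

For instance, with the linear objective $-\langle \theta, e_1 \rangle$ (gradient `[-1.0, 0.0, 0.0]`), the last iterate keeps fluctuating around $e_1$, whereas the averaged iterate cancels much of that noise.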

LGApr 25, 2024
REBEL: Reinforcement Learning via Regressing Relative Rewards

Zhaolin Gao, Jonathan D. Chang, Wenhao Zhan et al.

While originally developed for continuous control problems, Proximal Policy Optimization (PPO) has emerged as the workhorse of a variety of reinforcement learning (RL) applications, including the fine-tuning of generative models. Unfortunately, PPO requires multiple heuristics to enable stable convergence (e.g., value networks, clipping), and is notorious for its sensitivity to the precise implementation of these components. In response, we take a step back and ask what a minimalist RL algorithm for the era of generative models would look like. We propose REBEL, an algorithm that cleanly reduces the problem of policy optimization to regressing the relative reward between two completions to a prompt in terms of the policy, enabling strikingly lightweight implementation. In theory, we prove that fundamental RL algorithms like Natural Policy Gradient can be seen as variants of REBEL, which allows us to match the strongest known theoretical guarantees in terms of convergence and sample complexity in the RL literature. REBEL can also cleanly incorporate offline data and be extended to handle the intransitive preferences we frequently see in practice. Empirically, we find that REBEL provides a unified approach to language modeling and image generation with performance stronger than or similar to PPO and DPO, all while being simpler to implement and more computationally efficient than PPO. When fine-tuning Llama-3-8B-Instruct, REBEL achieves strong performance in AlpacaEval 2.0, MT-Bench, and Open LLM Leaderboard.

LGMar 19, 2025
What Makes a Reward Model a Good Teacher? An Optimization Perspective

Noam Razin, Zixuan Wang, Hubert Strauss et al. · princeton

The success of Reinforcement Learning from Human Feedback (RLHF) critically depends on the quality of the reward model. However, while this quality is primarily evaluated through accuracy, it remains unclear whether accuracy fully captures what makes a reward model an effective teacher. We address this question from an optimization perspective. First, we prove that regardless of how accurate a reward model is, if it induces low reward variance, then the RLHF objective suffers from a flat landscape. Consequently, even a perfectly accurate reward model can lead to extremely slow optimization, underperforming less accurate models that induce higher reward variance. We additionally show that a reward model that works well for one language model can induce low reward variance, and thus a flat objective landscape, for another. These results establish a fundamental limitation of evaluating reward models solely based on accuracy or independently of the language model they guide. Experiments using models of up to 8B parameters corroborate our theory, demonstrating the interplay between reward variance, accuracy, and reward maximization rate. Overall, our findings highlight that beyond accuracy, a reward model needs to induce sufficient variance for efficient optimization.
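
The two quantities contrasted above can be computed side by side for a single prompt with a finite set of completions: pairwise accuracy depends only on how the reward model ranks completions, while reward variance is taken under the policy's own distribution. The toy numbers in the usage note are hypothetical and only illustrate that a perfectly accurate reward model can still induce much lower variance than a less accurate one; the function names are also assumptions.

```python
from itertools import combinations


def reward_variance(probs, rewards):
    """Variance of r(x, y) when y is drawn from the policy's distribution `probs` over completions."""
    mean = sum(p * r for p, r in zip(probs, rewards))
    return sum(p * (r - mean) ** 2 for p, r in zip(probs, rewards))


def pairwise_accuracy(rewards, true_rewards):
    """Fraction of completion pairs the reward model orders the same way as the ground truth."""
    pairs = [(i, j) for i, j in combinations(range(len(rewards)), 2)
             if true_rewards[i] != true_rewards[j]]
    agree = sum((rewards[i] - rewards[j]) * (true_rewards[i] - true_rewards[j]) > 0
                for i, j in pairs)
    return agree / len(pairs)
```

For example, with ground-truth rewards `[0, 1, 2]` and policy probabilities `[0.9, 0.05, 0.05]`, the rescaled model `[0.0, 0.001, 0.002]` has pairwise accuracy 1.0 but reward variance on the order of $10^{-7}$, while `[0, 2, 1]` has accuracy 2/3 and variance about 0.23.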

LGFeb 15, 2024
BitDelta: Your Fine-Tune May Only Be Worth One Bit

James Liu, Guangxuan Xiao, Kai Li et al.

Large Language Models (LLMs) are typically trained in two phases: pre-training on large internet-scale datasets, and fine-tuning for downstream tasks. Given the higher computational demand of pre-training, it is intuitive to assume that fine-tuning adds less new information to the model, and is thus more compressible. We explore this assumption by decomposing the weights of fine-tuned models into their pre-trained components and an additional delta. We introduce a simple method, BitDelta, which successfully quantizes this delta down to 1 bit without compromising performance. This interesting finding not only highlights the potential redundancy of information added during fine-tuning, but also has significant implications for the multi-tenant serving and multi-tenant storage of fine-tuned models. By enabling the use of a single high-precision base model accompanied by multiple 1-bit deltas, BitDelta dramatically reduces GPU memory requirements by more than 10x, which can also translate into lower generation latency in multi-tenant settings. We validate BitDelta through experiments across the Llama-2 and Mistral model families, and on models up to 70B parameters, showcasing minimal performance degradation over all tested settings.
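
A minimal sketch of the 1-bit delta on a flat list of weights: keep only the sign of each entry of $\Delta = W_{\mathrm{fine}} - W_{\mathrm{base}}$ plus a single scale $\alpha$ per matrix. For fixed signs, $\alpha = \frac{1}{n}\sum_i |\Delta_i|$ minimizes the squared reconstruction error, which makes it a natural initialization; BitDelta further refines the scales by distillation, which this sketch omits. Function names are hypothetical.

```python
def bitdelta_compress(w_base, w_fine):
    """Return (alpha, signs) such that w_fine ~= w_base + alpha * signs."""
    delta = [f - b for f, b in zip(w_fine, w_base)]
    signs = [1 if d >= 0 else -1 for d in delta]  # one bit per weight
    # For fixed signs, the mean absolute delta minimizes sum_i (delta_i - alpha * sign_i)^2.
    alpha = sum(abs(d) for d in delta) / len(delta)
    return alpha, signs


def bitdelta_decompress(w_base, alpha, signs):
    """Reconstruct approximate fine-tuned weights from the shared base and the 1-bit delta."""
    return [b + alpha * s for b, s in zip(w_base, signs)]
```

In a multi-tenant setting, one copy of `w_base` would be shared, and each fine-tune would contribute only its `signs` (one bit per weight) and its scale.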

LGJan 28, 2024
An Information-Theoretic Analysis of In-Context Learning

Hong Jun Jeon, Jason D. Lee, Qi Lei et al.

Previous theoretical results on meta-learning over sequences rely on contrived assumptions and are somewhat convoluted. We introduce new information-theoretic tools that lead to an elegant and very general decomposition of error into three components: irreducible error, meta-learning error, and intra-task error. These tools unify analyses across many meta-learning challenges. To illustrate, we apply them to establish new results about in-context learning with transformers. Our theoretical results characterize how error decays in both the number of training sequences and the sequence length. The results are very general; for example, they avoid the contrived mixing-time assumptions made by all prior results that establish decay of error with sequence length.

LGMar 5, 2024
How Well Can Transformers Emulate In-context Newton's Method?

Angeliki Giannou, Liu Yang, Tianhao Wang et al.

Transformer-based models have demonstrated remarkable in-context learning capabilities, prompting extensive research into their underlying mechanisms. Recent studies have suggested that Transformers can implement first-order optimization algorithms for in-context learning, and even second-order ones in the case of linear regression. In this work, we study whether Transformers can perform higher-order optimization methods beyond linear regression. We establish that linear attention Transformers with ReLU layers can approximate second-order optimization algorithms for logistic regression and achieve $\epsilon$ error with a number of additional layers that is only logarithmic in the error. As a by-product, we show that even linear attention-only Transformers can implement a single step of Newton's iteration for matrix inversion with merely two layers. These results suggest that the Transformer architecture can implement complex algorithms beyond gradient descent.
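The matrix-inversion result refers to Newton's iteration for the inverse. As a plain NumPy illustration (not a transformer construction, and not the paper's code), the classical Newton-Schulz update X_{k+1} = X_k (2I - A X_k) maps the residual E_k = I - A X_k to E_k^2, so convergence is quadratic once the residual norm is below 1. The initial guess X_0 = A^T / (||A||_1 ||A||_inf) is a standard choice that guarantees ||E_0||_2 < 1 for nonsingular A; the test matrix and iteration count are arbitrary assumptions.

```python
import numpy as np


def newton_inverse_step(a, x):
    """One Newton step for inv(a): X <- X (2I - A X)."""
    return x @ (2 * np.eye(a.shape[0]) - a @ x)


def initial_guess(a):
    # ||A||_2^2 <= ||A||_1 * ||A||_inf, so ||I - A X0||_2 < 1 for nonsingular A.
    return a.T / (np.linalg.norm(a, 1) * np.linalg.norm(a, np.inf))


def newton_schulz_inverse(a, num_iters=30):
    x = initial_guess(a)
    for _ in range(num_iters):
        x = newton_inverse_step(a, x)
    return x


rng = np.random.default_rng(0)
a = 4 * np.eye(5) + rng.normal(scale=0.5, size=(5, 5))  # hypothetical well-conditioned matrix
x = newton_schulz_inverse(a)

# The residual after one step is the square of the residual before it.
x0 = initial_guess(a)
e0 = np.eye(5) - a @ x0
e1 = np.eye(5) - a @ newton_inverse_step(a, x0)
```

Each step uses only matrix products, which is presumably why it is a natural target for attention layers; how the paper realizes it with two linear-attention layers is not reproduced here.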

LGFeb 19, 2024
LoRA Training in the NTK Regime has No Spurious Local Minima

Uijeong Jang, Jason D. Lee, Ernest K. Ryu

Low-rank adaptation (LoRA) has become the standard approach for parameter-efficient fine-tuning of large language models (LLMs), but our theoretical understanding of LoRA remains limited. In this work, we theoretically analyze LoRA fine-tuning in the neural tangent kernel (NTK) regime with $N$ data points, showing that: (i) full fine-tuning (without LoRA) admits a low-rank solution of rank $r\lesssim \sqrt{N}$; (ii) using LoRA with rank $r\gtrsim \sqrt{N}$ eliminates spurious local minima, allowing gradient descent to find the low-rank solutions; and (iii) the low-rank solution found using LoRA generalizes well.
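For readers unfamiliar with the parametrization, the sketch below shows a LoRA update on a single frozen weight matrix, W = W0 + B A with A of shape (r, d_in) and B of shape (d_out, r), together with the sqrt(N) rank scale from the abstract. Zero-initializing B follows the common LoRA convention, and dropping the constants hidden in the asymptotic inequalities is our simplification; the class and function names are hypothetical, and this is not the paper's NTK analysis.

```python
import math

import numpy as np


def lora_rank_scale(num_examples):
    """Rank scale sqrt(N) from the abstract (constants in the bounds dropped)."""
    return math.ceil(math.sqrt(num_examples))


class LoRALinear:
    """Frozen weight w0 plus a trainable rank-r update b @ a."""

    def __init__(self, w0, rank, rng):
        d_out, d_in = w0.shape
        self.w0 = w0  # frozen pre-trained weight
        self.a = rng.normal(scale=1 / math.sqrt(d_in), size=(rank, d_in))
        self.b = np.zeros((d_out, rank))  # zero init: the layer is unchanged at step 0

    def weight(self):
        return self.w0 + self.b @ self.a

    def __call__(self, x):
        return x @ self.weight().T

    def num_trainable(self):
        return self.a.size + self.b.size


rng = np.random.default_rng(0)
r = lora_rank_scale(1_000)  # N = 1000 hypothetical examples, so r = 32
layer = LoRALinear(rng.normal(size=(256, 256)), r, rng)
```

With r = 32 on a 256 x 256 layer, only 32 * (256 + 256) = 16,384 of the 65,536 entries' worth of parameters are trained, and the update B A can never exceed rank r.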

LGDec 9, 2024
Understanding Factual Recall in Transformers via Associative Memories

Eshaan Nichani, Jason D. Lee, Alberto Bietti

Large language models have demonstrated an impressive ability to perform factual recall. Prior work has found that transformers trained on factual recall tasks can store information at a rate proportional to their parameter count. We show that shallow transformers can use a combination of associative memories to obtain such near-optimal storage capacity. We begin by proving that the storage capacities of both linear and MLP associative memories scale linearly with parameter count. We next introduce a synthetic factual recall task and prove that a transformer with a single layer of self-attention followed by an MLP can obtain 100% accuracy on the task whenever either the total number of self-attention parameters or the number of MLP parameters scales (up to log factors) linearly with the number of facts. In particular, the transformer can trade off between using the value matrices or the MLP as an associative memory to store the dataset of facts. We complement these expressivity results with an analysis of the gradient flow trajectory of a simplified linear attention model trained on our factual recall task, showing that the model exhibits sequential learning behavior.
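A toy version of a linear associative memory helps show why capacity tracks parameter count. The sketch below assumes random unit-norm embeddings for keys (subjects) and values (attributes) and stores each fact as a rank-1 outer product in a d x d matrix; this is a textbook construction chosen for illustration, not the paper's construction, and all sizes are hypothetical.

```python
import numpy as np


def random_unit_vectors(n, d, rng):
    v = rng.normal(size=(n, d))
    return v / np.linalg.norm(v, axis=1, keepdims=True)


def build_memory(keys, values, fact_to_value):
    """W = sum_i v_{f(i)} u_i^T: one rank-1 outer product per stored fact."""
    return values[fact_to_value].T @ keys  # shape (d, d)


def recall(w, keys, values):
    """Predict, for each key u_i, the value index k maximizing v_k^T W u_i."""
    scores = keys @ w.T @ values.T  # shape (num_facts, num_values)
    return scores.argmax(axis=1)


def recall_accuracy(num_facts, d, num_values, seed=0):
    rng = np.random.default_rng(seed)
    keys = random_unit_vectors(num_facts, d, rng)     # e.g. subject tokens
    values = random_unit_vectors(num_values, d, rng)  # e.g. attribute tokens
    fact_to_value = rng.integers(0, num_values, size=num_facts)
    w = build_memory(keys, values, fact_to_value)
    return (recall(w, keys, values) == fact_to_value).mean()


# d * d parameters: 256 * 256 = 65,536 comfortably hold 200 facts,
# while 32 * 32 = 1,024 parameters are overloaded by 2,000 facts.
acc_small = recall_accuracy(num_facts=200, d=256, num_values=50)
acc_overloaded = recall_accuracy(num_facts=2_000, d=32, num_values=50)
```

Cross-talk between stored facts grows with the number of facts relative to d * d, so recall degrades once the memory is overloaded; the paper's precise linear-in-parameters capacity bounds are not reproduced by this toy.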

LGMar 8, 2024
Computational-Statistical Gaps in Gaussian Single-Index Models

Alex Damian, Loucas Pillaud-Vivien, Jason D. Lee et al.

Single-index models are high-dimensional regression problems with planted structure, in which labels depend on an unknown one-dimensional projection of the input via a generic, non-linear, and potentially non-deterministic transformation. As such, they encompass a broad class of statistical inference tasks and provide a rich template for studying statistical and computational trade-offs in the high-dimensional regime. While the information-theoretic sample complexity to recover the hidden direction is linear in the dimension $d$, we show that computationally efficient algorithms, both within the Statistical Query (SQ) and the Low-Degree Polynomial (LDP) frameworks, necessarily require $\Omega(d^{k^\star/2})$ samples, where $k^\star$ is a "generative" exponent associated with the model that we explicitly characterize. Moreover, we show that this sample complexity is also sufficient by establishing matching upper bounds using a partial-trace algorithm. Therefore, our results provide evidence of a sharp computational-to-statistical gap (under both the SQ and LDP classes) whenever $k^\star>2$. To complete the study, we provide examples of smooth and Lipschitz deterministic target functions with arbitrarily large generative exponents $k^\star$.
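The model class and the size of the gap can be illustrated with a short sketch. It assumes isotropic Gaussian inputs, a hidden unit direction, and the Hermite polynomial He_3 as an arbitrary example link; the generative exponent of a given link is characterized in the paper and is not computed here, so the value k* = 4 in the arithmetic is hypothetical. Constants and log factors are dropped from both sample-size scales.

```python
import numpy as np


def he3(z):
    """Probabilists' Hermite polynomial He_3(z) = z^3 - 3z (an arbitrary example link)."""
    return z ** 3 - 3 * z


def sample_single_index(n, d, link, noise_std, rng):
    """Draw (x, y) with x ~ N(0, I_d) and y = link(<w, x>) + noise."""
    w = rng.normal(size=d)
    w /= np.linalg.norm(w)  # hidden unit direction
    x = rng.normal(size=(n, d))
    y = link(x @ w) + noise_std * rng.normal(size=n)
    return x, y, w


def efficient_sample_scale(d, k_star):
    """d^{k*/2}: the SQ/LDP scale from the abstract, constants and logs dropped."""
    return d ** (k_star / 2)


rng = np.random.default_rng(0)
x, y, w = sample_single_index(n=500, d=100, link=he3, noise_std=0.0, rng=rng)

# Ratio between the efficient-algorithm scale and the information-theoretic
# scale (linear in d) for a hypothetical k* = 4 at d = 1000.
gap = efficient_sample_scale(1000, 4) / 1000
```

For k* = 2 the scale d^{k*/2} reduces to d, consistent with the statement that the gap appears only when k* > 2; for k* = 4 at d = 1000 efficient algorithms need on the order of 1000 times more samples than the information-theoretic limit.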