Vaishnavh Nagarajan

LG
h-index: 28
22 papers
1,946 citations
Novelty: 55%
AI Score: 47

22 Papers

CL · Oct 3, 2023
Think before you speak: Training Language Models With Pause Tokens

Sachin Goyal, Ziwei Ji, Ankit Singh Rawat et al.

Language models generate responses by producing a series of tokens in immediate succession: the $(K+1)^{th}$ token is an outcome of manipulating $K$ hidden vectors per layer, one vector per preceding token. What if instead we were to let the model manipulate say, $K+10$ hidden vectors, before it outputs the $(K+1)^{th}$ token? We operationalize this idea by performing training and inference on language models with a (learnable) $\textit{pause}$ token, a sequence of which is appended to the input prefix. We then delay extracting the model's outputs until the last pause token is seen, thereby allowing the model to process extra computation before committing to an answer. We empirically evaluate $\textit{pause-training}$ on decoder-only models of 1B and 130M parameters with causal pretraining on C4, and on downstream tasks covering reasoning, question-answering, general understanding and fact recall. Our main finding is that inference-time delays show gains when the model is both pre-trained and finetuned with delays. For the 1B model, we witness gains on 8 of 9 tasks, most prominently, a gain of $18\%$ EM score on the QA task of SQuAD, $8\%$ on CommonSenseQA and $1\%$ accuracy on the reasoning task of GSM8k. Our work raises a range of conceptual and practical future research questions on making delayed next-token prediction a widely applicable new paradigm.
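The input construction described in this abstract can be sketched as follows. This is a minimal illustration, not the authors' implementation: `PAUSE_ID`, `append_pauses` and `first_read_position` are hypothetical names, and token ids are plain integers.

```python
from typing import List

PAUSE_ID = -1  # hypothetical id reserved for the learnable pause token (assumption)


def append_pauses(prefix: List[int], num_pauses: int, pause_id: int = PAUSE_ID) -> List[int]:
    """Return the input prefix followed by `num_pauses` copies of the pause token."""
    return list(prefix) + [pause_id] * num_pauses


def first_read_position(prefix_len: int, num_pauses: int) -> int:
    """Index of the first position whose output is kept.

    Outputs at earlier positions (the prefix and all but the last pause
    token) are ignored, so the first answer token is read at the last
    pause token. With zero pauses this is the last prefix token, which
    is ordinary next-token prediction.
    """
    return prefix_len + num_pauses - 1


prefix = [11, 42, 7]
padded = append_pauses(prefix, num_pauses=10)
print(len(padded), first_read_position(len(prefix), 10))
```

With a prefix of $K$ tokens and 10 pauses, the model has $K+10$ hidden vectors per layer available before its first output is read.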

LG · Jan 30, 2023
On student-teacher deviations in distillation: does it pay to disobey?

Vaishnavh Nagarajan, Aditya Krishna Menon, Srinadh Bhojanapalli et al.

Knowledge distillation (KD) has been widely used to improve the test accuracy of a "student" network, by training it to mimic the soft probabilities of a trained "teacher" network. Yet, it has been shown in recent work that, despite being trained to fit the teacher's probabilities, the student may not only significantly deviate from the teacher probabilities, but may also outperform the teacher. Our work aims to reconcile this seemingly paradoxical observation. Specifically, we characterize the precise nature of the student-teacher deviations, and argue how they can co-occur with better generalization. First, through experiments on image and language data, we identify that these probability deviations correspond to the student systematically exaggerating the confidence levels of the teacher. Next, we theoretically and empirically establish another form of exaggeration in some simple settings: KD exaggerates the implicit bias of gradient descent in converging faster along the top eigendirections of the data. Finally, we tie these two observations together: we demonstrate that the exaggerated bias of KD can simultaneously result in both (a) the exaggeration of confidence and (b) the improved generalization of the student, thus offering a resolution to the apparent paradox. Our analysis brings existing theory and practice closer by considering the role of gradient descent in KD and by demonstrating the exaggerated bias effect in both theoretical and empirical settings.
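For readers unfamiliar with the objective, training a student to mimic a teacher's soft probabilities is commonly written as a cross-entropy against the teacher's softmax output. The sketch below assumes that standard form; the function names and the `temperature` parameter are illustrative and not taken from the paper.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: List[float],
                      teacher_logits: List[float],
                      temperature: float = 1.0) -> float:
    """Cross-entropy of the student's probabilities against the teacher's soft probabilities."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))


teacher = [2.0, 1.0, 0.0]
print(distillation_loss(teacher, teacher))          # student matches the teacher
print(distillation_loss([4.0, 2.0, 0.0], teacher))  # student is more confident than the teacher
```

For a single example this loss is smallest when the student's probabilities equal the teacher's, so a student that ends up more confident than the teacher is deviating from the target it was trained on.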

LG · Feb 3, 2023
ResMem: Learn what you can and memorize the rest

Zitong Yang, Michal Lukasik, Vaishnavh Nagarajan et al.

The impressive generalization performance of modern neural networks is attributed in part to their ability to implicitly memorize complex training patterns. Inspired by this, we explore a novel mechanism to improve model generalization via explicit memorization. Specifically, we propose the residual-memorization (ResMem) algorithm, a new method that augments an existing prediction model (e.g. a neural network) by fitting the model's residuals with a $k$-nearest neighbor based regressor. The final prediction is then the sum of the original model and the fitted residual regressor. By construction, ResMem can explicitly memorize the training labels. Empirically, we show that ResMem consistently improves the test set generalization of the original prediction model across various standard vision and natural language processing benchmarks. Theoretically, we formulate a stylized linear regression problem and rigorously show that ResMem results in a more favorable test risk over the base predictor.
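The two-stage construction in this abstract (fit residuals with a $k$-nearest-neighbor regressor, then add the correction to the base prediction) can be sketched as below. This is a simplified illustration under stated assumptions: the neighbors are averaged without weighting, distances are Euclidean, and the class name `ResMem` and its methods are hypothetical rather than the authors' code.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class ResMem:
    """Base model plus an unweighted k-NN regressor fitted on its training residuals."""

    def __init__(self, base_model: Callable[[Vector], float], k: int = 1):
        self.base_model = base_model
        self.k = k
        self.train_x: List[Vector] = []
        self.residuals: List[float] = []

    def fit(self, xs: List[Vector], ys: List[float]) -> "ResMem":
        # Store each training input with the base model's residual on it.
        self.train_x = list(xs)
        self.residuals = [y - self.base_model(x) for x, y in zip(xs, ys)]
        return self

    def predict(self, x: Vector) -> float:
        # Average the residuals of the k nearest training points and add them
        # to the base model's prediction.
        order = sorted(range(len(self.train_x)),
                       key=lambda i: squared_distance(x, self.train_x[i]))
        nearest = order[: self.k]
        correction = sum(self.residuals[i] for i in nearest) / len(nearest)
        return self.base_model(x) + correction


base = lambda x: 2.0 * x[0]  # toy base predictor (assumption)
xs = [(0.0,), (1.0,), (2.0,), (3.0,)]
ys = [0.5, 2.0, 4.5, 6.0]
model = ResMem(base, k=1).fit(xs, ys)
print([model.predict(x) for x in xs])
```

With `k=1` and distinct training inputs, the combined predictor reproduces every training label, which is the explicit memorization the abstract refers to.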

LG · Oct 9, 2023
What do larger image classifiers memorise?

Michal Lukasik, Vaishnavh Nagarajan, Ankit Singh Rawat et al.

The success of modern neural networks has prompted study of the connection between memorisation and generalisation: overparameterised models generalise well, despite being able to perfectly fit (memorise) completely random labels. To carefully study this issue, Feldman proposed a metric to quantify the degree of memorisation of individual training examples, and empirically computed the corresponding memorisation profile of a ResNet on image classification benchmarks. While an exciting first glimpse into what real-world models memorise, this leaves open a fundamental question: do larger neural models memorise more? We present a comprehensive empirical analysis of this question on image classification benchmarks. We find that training examples exhibit an unexpectedly diverse set of memorisation trajectories across model sizes: most samples experience decreased memorisation under larger models, while the rest exhibit cap-shaped or increasing memorisation. We show that various proxies for the Feldman memorisation score fail to capture these fundamental trends. Lastly, we find that knowledge distillation, an effective and popular model compression technique, tends to inhibit memorisation, while also improving generalisation. Specifically, memorisation is mostly inhibited on examples with increasing memorisation trajectories, thus pointing at how distillation improves generalisation.

CL · Oct 7, 2023
The Cost of Down-Scaling Language Models: Fact Recall Deteriorates before In-Context Learning

Tian Jin, Nolan Clement, Xin Dong et al.

How does scaling the number of parameters in large language models (LLMs) affect their core capabilities? We study two natural scaling techniques -- weight pruning and simply training a smaller or larger model, which we refer to as dense scaling -- and their effects on two core capabilities of LLMs: (a) recalling facts presented during pre-training and (b) processing information presented in-context during inference. By curating a suite of tasks that help disentangle these two capabilities, we find a striking difference in how these two abilities evolve due to scaling. Reducing the model size by more than 30% (via either scaling approach) significantly decreases the ability to recall facts seen in pre-training. Yet, a 60--70% reduction largely preserves the various ways the model can process in-context information, ranging from retrieving answers from a long context to learning parameterized functions from in-context exemplars. The fact that both dense scaling and weight pruning exhibit this behavior suggests that scaling model size has an inherently disparate effect on fact recall and in-context learning.

CL · Mar 11, 2024 · Code
The pitfalls of next-token prediction

Gregor Bachmann, Vaishnavh Nagarajan

Can a mere next-token predictor faithfully model human intelligence? We crystallize this emerging concern and correct popular misconceptions surrounding it, and advocate a simple multi-token objective. As a starting point, we argue that the two often-conflated phases of next-token prediction -- autoregressive inference and teacher-forced training -- must be treated distinctly. The popular criticism that errors can compound during autoregressive inference crucially assumes that teacher-forcing has learned an accurate next-token predictor. This assumption sidesteps a more deep-rooted problem we expose: in certain classes of tasks, teacher-forcing can simply fail to learn an accurate next-token predictor in the first place. We describe a general mechanism of how teacher-forcing can fail, and design a minimal planning task where both the Transformer and the Mamba architecture empirically fail in that manner -- remarkably, despite the task being straightforward to learn. Finally, we provide preliminary evidence that this failure can be resolved using _teacherless_ training, a simple modification using dummy tokens that predicts multiple tokens in advance. We hope this finding can ground future debates and inspire explorations beyond the next-token prediction paradigm. We make our code available under https://github.com/gregorbachmann/Next-Token-Failures

LG · Apr 21, 2025 · Code
Roll the dice & look before you leap: Going beyond the creative limits of next-token prediction

Vaishnavh Nagarajan, Chen Henry Wu, Charles Ding et al. · CMU

We design a suite of minimal algorithmic tasks that are a loose abstraction of open-ended real-world tasks. This allows us to cleanly and controllably quantify the creative limits of the present-day language model. Much like real-world tasks that require a creative, far-sighted leap of thought, our tasks require an implicit, open-ended stochastic planning step that either (a) discovers new connections in an abstract knowledge graph (like in wordplay, drawing analogies, or research) or (b) constructs new patterns (like in designing math problems or new proteins). In these tasks, we empirically and conceptually argue how next-token learning is myopic; multi-token approaches, namely teacherless training and diffusion models, comparatively excel in producing diverse and original output. Secondly, to elicit randomness without hurting coherence, we find that injecting noise at the input layer (dubbed seed-conditioning) works surprisingly as well as (and in some conditions, better than) temperature sampling from the output layer. Thus, our work offers a principled, minimal test-bed for analyzing open-ended creative skills, and offers new arguments for going beyond next-token learning and temperature sampling. We make part of the code available under https://github.com/chenwu98/algorithmic-creativity

LG · Oct 30, 2025
Deep sequence models tend to memorize geometrically; it is unclear why

Shahriar Noroozizadeh, Vaishnavh Nagarajan, Elan Rosenfeld et al.

In sequence modeling, the parametric memory of atomic facts has been predominantly abstracted as a brute-force lookup of co-occurrences between entities. We contrast this associative view against a geometric view of how memory is stored. We begin by isolating a clean and analyzable instance of Transformer reasoning that is incompatible with memory as strictly a storage of the local co-occurrences specified during training. Instead, the model must have somehow synthesized its own geometry of atomic facts, encoding global relationships between all entities, including non-co-occurring ones. This in turn has simplified a hard reasoning task involving an $\ell$-fold composition into an easy-to-learn 1-step geometric task. From this phenomenon, we extract fundamental aspects of neural embedding geometries that are hard to explain. We argue that the rise of such a geometry, despite optimizing over mere local associations, cannot be straightforwardly attributed to typical architectural or optimizational pressures. Counterintuitively, an elegant geometry is learned even when it is not more succinct than a brute-force lookup of associations. Then, by analyzing a connection to Node2Vec, we demonstrate how the geometry stems from a spectral bias that -- in contrast to prevailing theories -- indeed arises naturally despite the lack of various pressures. This analysis also points to practitioners a visible headroom to make Transformer memory more strongly geometric. We hope the geometric view of parametric memory encourages revisiting the default intuitions that guide researchers in areas like knowledge acquisition, capacity, discovery and unlearning.

CL · Oct 13, 2025
Catch Your Breath: Adaptive Computation for Self-Paced Sequence Production

Alexandre Galashov, Matt Jones, Rosemary Ke et al. · DeepMind

We explore a class of supervised training objectives that allow a language model to dynamically and autonomously scale the number of compute steps used for each input token. For any token, the model can request additional compute steps by emitting a <don't know> output. If the model is granted a delay, a specialized <pause> token is inserted at the next input step, providing the model with additional compute resources to generate an output. The model can request multiple pauses. To train the model to use <don't know> outputs judiciously and to calibrate its uncertainty, we frame the selection of each output token as a sequential-decision problem with a time cost. We refer to the class of methods as $\textit{Catch Your Breath}$ losses and we study three methods in this class: CYB-AP frames the model's task as anytime prediction, where an output may be required at any step and accuracy is discounted over time; CYB-VA is a variational approach that aims to maximize prediction accuracy subject to a specified distribution over stopping times; and CYB-DP imposes a penalty based on a computational budget. Through fine-tuning experiments, we identify the best performing loss variant. The CYB model needs only one third as much training data as the baseline (no pause) model needs to achieve the same performance, and half as much data as a model with pauses and a cross-entropy loss. We find that the CYB model requests additional steps when doing so improves accuracy, and the model adapts its processing time to token-level complexity and context. For example, it often pauses after plural nouns like $\textit{patients}$ and $\textit{challenges}$ but never pauses after the first token of contracted words like $\textit{wasn}$ and $\textit{didn}$, and it shows high variability for ambiguous tokens like $\textit{won}$, which could function as either a verb or part of a contraction.

LG · Oct 17, 2021
Explaining generalization in deep learning: progress and fundamental limits

Vaishnavh Nagarajan

This dissertation studies a fundamental open challenge in deep learning theory: why do deep networks generalize well even while being overparameterized, unregularized and fitting the training data to zero error? In the first part of the thesis, we will empirically study how training deep networks via stochastic gradient descent implicitly controls the networks' capacity. Subsequently, to show how this leads to better generalization, we will derive {\em data-dependent} {\em uniform-convergence-based} generalization bounds with improved dependencies on the parameter count. Uniform convergence has in fact been the most widely used tool in deep learning literature, thanks to its simplicity and generality. Given its popularity, in this thesis, we will also take a step back to identify the fundamental limits of uniform convergence as a tool to explain generalization. In particular, we will show that in some example overparameterized settings, {\em any} uniform convergence bound will provide only a vacuous generalization bound. With this realization in mind, in the last part of the thesis, we will change course and introduce an {\em empirical} technique to estimate generalization using unlabeled data. Our technique does not rely on any notion of uniform-convergence-based complexity and is remarkably precise. We will theoretically show why our technique enjoys such precision. We will conclude by discussing how future work could explore novel ways to incorporate distributional assumptions in generalization bounds (such as in the form of unlabeled data) and explore other tools to derive bounds, perhaps by modifying uniform convergence or by developing completely new tools altogether.

LG · Jun 25, 2021
Assessing Generalization of SGD via Disagreement

Yiding Jiang, Vaishnavh Nagarajan, Christina Baek et al.

We empirically show that the test error of deep networks can be estimated by simply training the same architecture on the same training set but with a different run of Stochastic Gradient Descent (SGD), and measuring the disagreement rate between the two networks on unlabeled test data. This builds on -- and is a stronger version of -- the observation in Nakkiran & Bansal '20, which requires the second run to be on an altogether fresh training set. We further theoretically show that this peculiar phenomenon arises from the \emph{well-calibrated} nature of \emph{ensembles} of SGD-trained models. This finding not only provides a simple empirical measure to directly predict the test error using unlabeled test data, but also establishes a new conceptual connection between generalization and calibration.
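The quantity measured in this abstract, the disagreement rate between two independently trained networks on unlabeled test data, is simple to compute once both sets of predictions are available. The sketch below assumes predictions are given as lists of class labels; `disagreement_rate` is an illustrative name.

```python
from typing import Hashable, Sequence


def disagreement_rate(preds_a: Sequence[Hashable], preds_b: Sequence[Hashable]) -> float:
    """Fraction of inputs on which two models predict different labels.

    No ground-truth labels are needed, only the two models' predictions
    on the same unlabeled inputs.
    """
    if len(preds_a) != len(preds_b) or len(preds_a) == 0:
        raise ValueError("prediction lists must be non-empty and of equal length")
    return sum(a != b for a, b in zip(preds_a, preds_b)) / len(preds_a)


# Predictions of two SGD runs of the same architecture on the same training set (toy values).
run_1 = [0, 1, 2, 1]
run_2 = [0, 2, 2, 0]
print(disagreement_rate(run_1, run_2))
```

The paper's claim is that this number tracks the test error of the networks; the code only computes the statistic and does not by itself demonstrate that relationship.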

LG · Nov 2, 2020
A Learning Theoretic Perspective on Local Explainability

Jeffrey Li, Vaishnavh Nagarajan, Gregory Plumb et al.

In this paper, we explore connections between interpretable machine learning and learning theory through the lens of local approximation explanations. First, we tackle the traditional problem of performance generalization and bound the test-time accuracy of a model using a notion of how locally explainable it is. Second, we explore the novel problem of explanation generalization which is an important concern for a growing class of finite sample-based local approximation explanations. Finally, we validate our theoretical results empirically and show that they reflect what can be seen in practice.

LG · Oct 29, 2020
Understanding the Failure Modes of Out-of-Distribution Generalization

Vaishnavh Nagarajan, Anders Andreassen, Behnam Neyshabur

Empirical studies suggest that machine learning models often rely on features, such as the background, that may be spuriously correlated with the label only during training time, resulting in poor accuracy during test-time. In this work, we identify the fundamental factors that give rise to this behavior, by explaining why models fail this way {\em even} in easy-to-learn tasks where one would expect these models to succeed. In particular, through a theoretical study of gradient-descent-trained linear classifiers on some easy-to-learn tasks, we uncover two complementary failure modes. These modes arise from how spurious correlations induce two kinds of skews in the data: one geometric in nature, and another, statistical in nature. Finally, we construct natural modifications of image classification datasets to understand when these failure modes can arise in practice. We also design experiments to isolate the two failure modes when training modern neural networks on these datasets.

LG · Jul 7, 2020
Provably Safe PAC-MDP Exploration Using Analogies

Melrose Roderick, Vaishnavh Nagarajan, J. Zico Kolter

A key challenge in applying reinforcement learning to safety-critical domains is understanding how to balance exploration (needed to attain good performance on the task) with safety (needed to avoid catastrophic failure). Although a growing line of work in reinforcement learning has investigated this area of "safe exploration," most existing techniques either 1) do not guarantee safety during the actual exploration process; and/or 2) limit the problem to a priori known and/or deterministic transition dynamics with strong smoothness assumptions. Addressing this gap, we propose Analogous Safe-state Exploration (ASE), an algorithm for provably safe exploration in MDPs with unknown, stochastic dynamics. Our method exploits analogies between state-action pairs to safely learn a near-optimal policy in a PAC-MDP sense. Additionally, ASE also guides exploration towards the most task-relevant states, which empirically results in significant improvements in terms of sample efficiency, when compared to existing methods.

LG · May 30, 2019
Deterministic PAC-Bayesian generalization bounds for deep networks via generalizing noise-resilience

Vaishnavh Nagarajan, J. Zico Kolter

The ability of overparameterized deep networks to generalize well has been linked to the fact that stochastic gradient descent (SGD) finds solutions that lie in flat, wide minima in the training loss -- minima where the output of the network is resilient to small random noise added to its parameters. So far this observation has been used to provide generalization guarantees only for neural networks whose parameters are either \textit{stochastic} or \textit{compressed}. In this work, we present a general PAC-Bayesian framework that leverages this observation to provide a bound on the original network learned -- a network that is deterministic and uncompressed. What enables us to do this is a key novelty in our approach: our framework allows us to show that if on training data, the interactions between the weight matrices satisfy certain conditions that imply a wide training loss minimum, these conditions themselves {\em generalize} to the interactions between the matrices on test data, thereby implying a wide test loss minimum. We then apply our general framework in a setup where we assume that the pre-activation values of the network are not too small (although we assume this only on the training data). In this setup, we provide a generalization guarantee for the original (deterministic, uncompressed) network, that does not scale with product of the spectral norms of the weight matrices -- a guarantee that would not have been possible with prior approaches.

LG · Feb 13, 2019
Uniform convergence may be unable to explain generalization in deep learning

Vaishnavh Nagarajan, J. Zico Kolter

Aimed at explaining the surprisingly good generalization behavior of overparameterized deep networks, recent works have developed a variety of generalization bounds for deep learning, all based on the fundamental learning-theoretic technique of uniform convergence. While it is well-known that many of these existing bounds are numerically large, through numerous experiments, we bring to light a more concerning aspect of these bounds: in practice, these bounds can {\em increase} with the training dataset size. Guided by our observations, we then present examples of overparameterized linear classifiers and neural networks trained by gradient descent (GD) where uniform convergence provably cannot "explain generalization" -- even if we take into account the implicit bias of GD {\em to the fullest extent possible}. More precisely, even if we consider only the set of classifiers output by GD, which have test errors less than some small $ε$ in our settings, we show that applying (two-sided) uniform convergence on this set of classifiers will yield only a vacuous generalization guarantee larger than $1-ε$. Through these findings, we cast doubt on the power of uniform convergence-based generalization bounds to provide a complete picture of why overparameterized deep networks generalize well.

LG · Jan 7, 2019
Generalization in Deep Networks: The Role of Distance from Initialization

Vaishnavh Nagarajan, J. Zico Kolter

Why does training deep neural networks using stochastic gradient descent (SGD) result in a generalization error that does not worsen with the number of parameters in the network? To answer this question, we advocate a notion of effective model capacity that is dependent on {\em a given random initialization of the network} and not just the training algorithm and the data distribution. We provide empirical evidence that the model capacity of SGD-trained deep networks is in fact restricted through implicit regularization of {\em the $\ell_2$ distance from the initialization}. We also provide theoretical arguments that further highlight the need for initialization-dependent notions of model capacity. We leave as open questions how and why distance from initialization is regularized, and whether it is sufficient to explain generalization.
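The $\ell_2$ distance from initialization mentioned in this abstract can be measured by flattening all parameters and comparing the trained values with the initial ones. The sketch below assumes parameters are stored as a list of per-layer lists of floats; that layout and the function name are illustrative assumptions.

```python
import math
from typing import List


def l2_distance_from_init(init_params: List[List[float]],
                          trained_params: List[List[float]]) -> float:
    """Euclidean distance between trained and initial parameters, over all layers."""
    total = 0.0
    for init_layer, trained_layer in zip(init_params, trained_params):
        if len(init_layer) != len(trained_layer):
            raise ValueError("layer shapes must match")
        total += sum((w - w0) ** 2 for w0, w in zip(init_layer, trained_layer))
    return math.sqrt(total)


init = [[0.0, 0.0], [1.0]]     # toy two-layer parameter set at initialization
trained = [[3.0, 4.0], [1.0]]  # the same parameters after training
print(l2_distance_from_init(init, trained))
```

Tracking this value across networks of increasing width is one way to check empirically whether the distance stays restricted as the parameter count grows.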

ML · Jun 7, 2018
Revisiting Adversarial Risk

Arun Sai Suggala, Adarsh Prasad, Vaishnavh Nagarajan et al.

Recent works on adversarial perturbations show that there is an inherent trade-off between standard test accuracy and adversarial accuracy. Specifically, they show that no classifier can simultaneously be robust to adversarial perturbations and achieve high standard test accuracy. However, this is contrary to the standard notion that on tasks such as image classification, humans are robust classifiers with low error rate. In this work, we show that the main reason behind this confusion is the inexact definition of adversarial perturbation that is used in the literature. To fix this issue, we propose a slight, yet important modification to the existing definition of adversarial perturbation. Based on the modified definition, we show that there is no trade-off between adversarial and standard accuracies; there exist classifiers that are robust and achieve high standard accuracy. We further study several properties of this new definition of adversarial risk and its relation to the existing definition.

LG · Jun 30, 2017
Lifelong Learning in Costly Feature Spaces

Maria-Florina Balcan, Avrim Blum, Vaishnavh Nagarajan

An important long-term goal in machine learning systems is to build learning agents that, like humans, can learn many tasks over their lifetime, and moreover use information from these tasks to improve their ability to do so efficiently. In this work, our goal is to provide new theoretical insights into the potential of this paradigm. In particular, we propose a lifelong learning framework that adheres to a novel notion of resource efficiency that is critical in many real-world domains where feature evaluations are costly. That is, our learner aims to reuse information from previously learned related tasks to learn future tasks in a feature-efficient manner. Furthermore, we consider novel combinatorial ways in which learning tasks can relate. Specifically, we design lifelong learning algorithms for two structurally different and widely used families of target functions: decision trees/lists and monomials/polynomials. We also provide strong feature-efficiency guarantees for these algorithms; in fact, we show that in order to learn future targets, we need only slightly more feature evaluations per training example than what is needed to predict on an arbitrary example using those targets. We also provide algorithms with guarantees in an agnostic model where not all the targets are related to each other. Finally, we also provide lower bounds on the performance of a lifelong learner in these models, which are in fact tight under some conditions.

LG · Jun 13, 2017
Gradient descent GAN optimization is locally stable

Vaishnavh Nagarajan, J. Zico Kolter

Despite the growing prominence of generative adversarial networks (GANs), optimization in GANs is still a poorly understood topic. In this paper, we analyze the "gradient descent" form of GAN optimization i.e., the natural setting where we simultaneously take small gradient steps in both generator and discriminator parameters. We show that even though GAN optimization does not correspond to a convex-concave game (even for simple parameterizations), under proper conditions, equilibrium points of this optimization procedure are still \emph{locally asymptotically stable} for the traditional GAN formulation. On the other hand, we show that the recently proposed Wasserstein GAN can have non-convergent limit cycles near equilibrium. Motivated by this stability analysis, we propose an additional regularization term for gradient descent GAN updates, which \emph{is} able to guarantee local stability for both the WGAN and the traditional GAN, and also shows practical promise in speeding up convergence and addressing mode collapse.
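A toy picture of the simultaneous gradient updates discussed in this abstract is the bilinear game $V(x, y) = xy$, where a "generator" $x$ descends on $V$ and a "discriminator" $y$ ascends on it. The sketch below is not the paper's setting or algorithm: the game, the step size, and the form of the penalty (adding $\gamma\,(\partial V/\partial y)^2$ to the generator's objective) are assumptions chosen for illustration.

```python
import math
from typing import Tuple


def simultaneous_steps(x: float, y: float, lr: float = 0.1,
                       gamma: float = 0.0, steps: int = 200) -> Tuple[float, float]:
    """Simultaneous gradient steps on V(x, y) = x * y.

    x minimizes V + gamma * (dV/dy)^2 and y maximizes V. Here dV/dy = x,
    so the penalty term equals gamma * x**2 (assumed form of the regularizer).
    """
    for _ in range(steps):
        grad_x = y + 2.0 * gamma * x  # gradient of the generator's objective in x
        grad_y = x                    # gradient of V in y
        # Both players update from the same current point.
        x, y = x - lr * grad_x, y + lr * grad_y
    return x, y


start = (1.0, 1.0)
plain = simultaneous_steps(*start, gamma=0.0)
regularized = simultaneous_steps(*start, gamma=0.5)
print(math.hypot(*plain), math.hypot(*regularized))
```

In this toy game the unregularized iterates rotate around the equilibrium at the origin and drift away from it, while the penalized iterates spiral inward, which is the qualitative effect a stabilizing regularizer is meant to have.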

DS · Nov 14, 2016
Learning-Theoretic Foundations of Algorithm Configuration for Combinatorial Partitioning Problems

Maria-Florina Balcan, Vaishnavh Nagarajan, Ellen Vitercik et al.

Max-cut, clustering, and many other partitioning problems that are of significant importance to machine learning and other scientific fields are NP-hard, a reality that has motivated researchers to develop a wealth of approximation algorithms and heuristics. Although the best algorithm to use typically depends on the specific application domain, a worst-case analysis is often used to compare algorithms. This may be misleading if worst-case instances occur infrequently, and thus there is a demand for optimization methods which return the algorithm configuration best suited for the given application's typical inputs. We address this problem for clustering, max-cut, and other partitioning problems, such as integer quadratic programming, by designing computationally efficient and sample efficient learning algorithms which receive samples from an application-specific distribution over problem instances and learn a partitioning algorithm with high expected performance. Our algorithms learn over common integer quadratic programming and clustering algorithm families: SDP rounding algorithms and agglomerative clustering algorithms with dynamic programming. For our sample complexity analysis, we provide tight bounds on the pseudodimension of these algorithm classes, and show that surprisingly, even for classes of algorithms parameterized by a single parameter, the pseudo-dimension is superconstant. In this way, our work both contributes to the foundations of algorithm configuration and pushes the boundaries of learning theory, since the algorithm classes we analyze consist of multi-stage optimization procedures and are significantly more complex than classes typically studied in learning theory.

LG · Jul 24, 2015
A Reinforcement Learning Approach to Online Learning of Decision Trees

Abhinav Garlapati, Aditi Raghunathan, Vaishnavh Nagarajan et al.

Online decision tree learning algorithms typically examine all features of a new data point to update model parameters. We propose a novel alternative, Reinforcement Learning-based Decision Trees (RLDT), that uses Reinforcement Learning (RL) to actively examine a minimal number of features of a data point to classify it with high accuracy. Furthermore, RLDT optimizes a long term return, providing a better alternative to the traditional myopic greedy approach to growing decision trees. We demonstrate that this approach performs as well as batch learning algorithms and other online decision tree learning algorithms, while making significantly fewer queries about the features of the data points. We also show that RLDT can effectively handle concept drift.