Andrew Saxe

LG
h-index: 8
19 papers
713 citations
Novelty: 56%
AI Score: 52

19 Papers

LG · Jun 16, 2022
Know your audience: specializing grounded language models with listener subtraction

Aaditya K. Singh, David Ding, Andrew Saxe et al. · DeepMind, Stanford

Effective communication requires adapting to the idiosyncrasies of each communicative context, such as the common ground shared with each partner. Humans demonstrate this ability to specialize to their audience in many contexts, such as the popular game Dixit. We take inspiration from Dixit to formulate a multi-agent image reference game where a (trained) speaker model is rewarded for describing a target image such that one (pretrained) listener model can correctly identify it among distractors, but another listener cannot. To adapt, the speaker must exploit differences in the knowledge it shares with the different listeners. We show that finetuning an attention-based adapter between a CLIP vision encoder and a large language model in this contrastive, multi-agent setting gives rise to context-dependent natural language specialization from rewards only, without direct supervision. Through controlled experiments, we show that training a speaker with two listeners that perceive differently, using our method, allows the speaker to adapt to the idiosyncrasies of the listeners. Furthermore, we show zero-shot transfer of the specialization to real-world data. Our experiments demonstrate a method for specializing grounded language models without direct supervision and highlight the interesting research challenges posed by complex multi-agent communication.

NC · Mar 22, 2022
Modelling continual learning in humans with Hebbian context gating and exponentially decaying task signals

Timo Flesch, David G. Nagy, Andrew Saxe et al.

Humans can learn several tasks in succession with minimal mutual interference but perform more poorly when trained on multiple tasks at once. The opposite is true for standard deep neural networks. Here, we propose novel computational constraints for artificial neural networks, inspired by earlier work on gating in the primate prefrontal cortex, that capture the cost of interleaved training and allow the network to learn two tasks in sequence without forgetting. We augment standard stochastic gradient descent with two algorithmic motifs, so-called "sluggish" task units and a Hebbian training step that strengthens connections between task units and hidden units that encode task-relevant information. We found that the "sluggish" units introduce a switch-cost during training, which biases representations under interleaved training towards a joint representation that ignores the contextual cue, while the Hebbian step promotes the formation of a gating scheme from task units to the hidden layer that produces orthogonal representations which are perfectly guarded against interference. Validating the model on previously published human behavioural data revealed that it matches performance of participants who had been trained on blocked or interleaved curricula, and that these performance differences were driven by misestimation of the true category boundary.
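One way to picture the "sluggish" task units described above is as a leaky running average of the task cue. The following is a minimal sketch under that assumption and is not the authors' implementation; the function name `sluggish_task_signal`, the carry-over parameter `alpha`, and the one-hot cue format are all hypothetical.

```python
def sluggish_task_signal(cues, alpha):
    """Exponentially weighted average of one-hot task cues.

    cues  : list of one-hot lists, one per trial (assumed input format)
    alpha : carry-over from the previous trial, 0 <= alpha < 1 (hypothetical name)
    """
    state = [0.0] * len(cues[0])
    history = []
    for cue in cues:
        # Blend the previous state with the current cue; older cues decay exponentially.
        state = [alpha * s + (1.0 - alpha) * c for s, c in zip(state, cue)]
        history.append(state)
    return history


task_a, task_b = [1.0, 0.0], [0.0, 1.0]
blocked = sluggish_task_signal([task_a] * 20, alpha=0.8)
interleaved = sluggish_task_signal([task_a, task_b] * 10, alpha=0.8)
print("blocked, final trial:    ", blocked[-1])
print("interleaved, final trial:", interleaved[-1])
```

In this toy setting the blocked schedule drives the signal close to a clean one-hot cue, while the interleaved schedule leaves both task units partly active. That blurring of the contextual cue is in the spirit of the switch cost the abstract attributes to sluggish units.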

NC · Oct 10, 2022
Continual task learning in natural and artificial agents

Timo Flesch, Andrew Saxe, Christopher Summerfield

How do humans and other animals learn new tasks? A wave of brain recording studies has investigated how neural representations change during task learning, with a focus on how tasks can be acquired and coded in ways that minimise mutual interference. We review recent work that has explored the geometry and dimensionality of neural task representations in neocortex, and computational models that have exploited these findings to understand how the brain may partition knowledge between tasks. We discuss how ideas from machine learning, including those that combine supervised and unsupervised learning, are helping neuroscientists understand how natural tasks are learned and coded in biological brains.

ML · May 18, 2022
Maslow's Hammer for Catastrophic Forgetting: Node Re-Use vs Node Activation

Sebastian Lee, Stefano Sarao Mannelli, Claudia Clopath et al.

Continual learning - learning new tasks in sequence while maintaining performance on old tasks - remains particularly challenging for artificial neural networks. Surprisingly, the amount of forgetting does not increase with the dissimilarity between the learned tasks, but appears to be worst in an intermediate similarity regime. In this paper we theoretically analyse both a synthetic teacher-student framework and a real data setup to provide an explanation of this phenomenon that we name Maslow's hammer hypothesis. Our analysis reveals the presence of a trade-off between node activation and node re-use that results in worst forgetting in the intermediate regime. Using this understanding we reinterpret popular algorithmic interventions for catastrophic interference in terms of this trade-off, and identify the regimes in which they are most effective.

LG · Jun 17, 2023
The RL Perceptron: Generalisation Dynamics of Policy Learning in High Dimensions

Nishil Patel, Sebastian Lee, Stefano Sarao Mannelli et al.

Reinforcement learning (RL) algorithms have proven transformative in a range of domains. To tackle real-world domains, these systems often use neural networks to learn policies directly from pixels or other high-dimensional sensory input. By contrast, much theory of RL has focused on discrete state spaces or worst-case analysis, and fundamental questions remain about the dynamics of policy learning in high-dimensional settings. Here, we propose a solvable high-dimensional model of RL that can capture a variety of learning protocols, and derive its typical dynamics as a set of closed-form ordinary differential equations (ODEs). We derive optimal schedules for the learning rates and task difficulty - analogous to annealing schemes and curricula during training in RL - and show that the model exhibits rich behaviour, including delayed learning under sparse rewards; a variety of learning regimes depending on reward baselines; and a speed-accuracy trade-off driven by reward stringency. Experiments on variants of the Procgen game "Bossfight" and Arcade Learning Environment game "Pong" also show such a speed-accuracy trade-off in practice. Together, these results take a step towards closing the gap between theory and practice in high-dimensional RL.

LG · May 19
Optimal Representation Size: High-Dimensional Analysis of Pretraining and Linear Probing

Valentina Njaradi, Clémentine Dominé, Rachel Swanson et al.

Learning to generalise from limited data is a fundamental challenge for both artificial and biological systems. A common strategy is to extract reusable structure from abundant unlabelled data, enabling efficient adaptation to new tasks from limited labelled data. This two-stage paradigm is now standard in modern training pipelines, where pretraining is followed by fine-tuning or linear probing. We provide an analytical model of this process: structure extraction is formalized as principal component analysis on unlabelled data, and downstream learning as linear regression on a separate labelled dataset. In the high-dimensional regime, we derive exact expressions for training and generalisation error showcasing their dependence on representation dimensionality, unlabelled and labelled sample sizes, and task alignment. Our results show that pretrained representations strongly influence downstream generalisation, and we characterize the optimal representation size as a function of task parameters: with abundant pretraining data but scarce downstream data, maximally compressed representations are optimal, whereas with limited pretraining data, higher-dimensional representations generalise better. Furthermore, we establish an exact trade-off between pretraining and supervision, quantifying how much unlabelled data is required to replace a single labelled sample. Beyond our idealised model, we observe similar phenomenology in autoencoders and pretrained LLMs. Altogether, we highlight that optimising representation size is critical, giving conditions for when compression during pretraining improves generalisation.

LG · Dec 23, 2025
Saddle-to-Saddle Dynamics Explains A Simplicity Bias Across Neural Network Architectures

Yedi Zhang, Andrew Saxe, Peter E. Latham

Neural networks trained with gradient descent often learn solutions of increasing complexity over time, a phenomenon known as simplicity bias. Despite being widely observed across architectures, existing theoretical treatments lack a unifying framework. We present a theoretical framework that explains a simplicity bias arising from saddle-to-saddle learning dynamics for a general class of neural networks, incorporating fully-connected, convolutional, and attention-based architectures. Here, simple means expressible with few hidden units, i.e., hidden neurons, convolutional kernels, or attention heads. Specifically, we show that linear networks learn solutions of increasing rank, ReLU networks learn solutions with an increasing number of kinks, convolutional networks learn solutions with an increasing number of convolutional kernels, and self-attention models learn solutions with an increasing number of attention heads. By analyzing fixed points, invariant manifolds, and dynamics of gradient descent learning, we show that saddle-to-saddle dynamics operates by iteratively evolving near an invariant manifold, approaching a saddle, and switching to another invariant manifold. Our analysis also illuminates the effects of data distribution and weight initialization on the duration and number of plateaus in learning, dissociating previously confounding factors. Overall, our theory offers a framework for understanding when and why gradient descent progressively learns increasingly complex solutions.

LG · Jan 12
Optimal Learning Rate Schedule for Balancing Effort and Performance

Valentina Njaradi, Rodrigo Carrasco-Davis, Peter E. Latham et al.

Learning how to learn efficiently is a fundamental challenge for biological agents and a growing concern for artificial ones. To learn effectively, an agent must regulate its learning speed, balancing the benefits of rapid improvement against the costs of effort, instability, or resource use. We introduce a normative framework that formalizes this problem as an optimal control process in which the agent maximizes cumulative performance while incurring a cost of learning. From this objective, we derive a closed-form solution for the optimal learning rate, which has the form of a closed-loop controller that depends only on the agent's current and expected future performance. Under mild assumptions, this solution generalizes across tasks and architectures and reproduces numerically optimized schedules in simulations. In simple learning models, we can mathematically analyze how agent and task parameters shape learning-rate scheduling as an open-loop control solution. Because the optimal policy depends on expectations of future performance, the framework predicts how overconfidence or underconfidence influence engagement and persistence, linking the control of learning speed to theories of self-regulated learning. We further show how a simple episodic memory mechanism can approximate the required performance expectations by recalling similar past learning experiences, providing a biologically plausible route to near-optimal behaviour. Together, these results provide a normative and biologically plausible account of learning speed control, linking self-regulated learning, effort allocation, and episodic memory estimation within a unified and tractable mathematical framework.

LG · Jan 27, 2025
Training Dynamics of In-Context Learning in Linear Attention

Yedi Zhang, Aaditya K. Singh, Peter E. Latham et al.

While attention-based models have demonstrated the remarkable ability of in-context learning (ICL), the theoretical understanding of how these models acquired this ability through gradient descent training is still preliminary. Towards answering this question, we study the gradient descent dynamics of multi-head linear self-attention trained for in-context linear regression. We examine two parametrizations of linear self-attention: one with the key and query weights merged as a single matrix (common in theoretical studies), and one with separate key and query matrices (closer to practical settings). For the merged parametrization, we show that the training dynamics has two fixed points and the loss trajectory exhibits a single, abrupt drop. We derive an analytical time-course solution for a certain class of datasets and initialization. For the separate parametrization, we show that the training dynamics has exponentially many fixed points and the loss exhibits saddle-to-saddle dynamics, which we reduce to scalar ordinary differential equations. During training, the model implements principal component regression in context with the number of principal components increasing over training time. Overall, we provide a theoretical description of how ICL abilities evolve during gradient descent training of linear attention, revealing abrupt acquisition or progressive improvements depending on how the key and query are parametrized.

LG · Mar 5, 2025
Revisiting the Role of Relearning in Semantic Dementia

Devon Jarvis, Verena Klar, Richard Klein et al.

Patients with semantic dementia (SD) present with remarkably consistent atrophy of neurons in the anterior temporal lobe and behavioural impairments, such as graded loss of category knowledge. While relearning of lost knowledge has been shown in acute brain injuries such as stroke, it has not been widely supported in chronic cognitive diseases such as SD. Previous research has shown that deep linear artificial neural networks exhibit stages of semantic learning akin to humans. Here, we use a deep linear network to test the hypothesis that relearning during disease progression, rather than particular atrophy, causes the specific behavioural patterns associated with SD. After training the network to generate the common semantic features of various hierarchically organised objects, neurons are successively deleted to mimic atrophy while retraining the model. The model with relearning and deleted neurons reproduced errors specific to SD, including prototyping errors and cross-category confusions. This suggests that relearning is necessary for artificial neural networks to reproduce the behavioural patterns associated with SD in the absence of output non-linearities. Our results support a theory of SD progression that results from continuous relearning of lost information. Future research should revisit the role of relearning as a contributing factor to cognitive diseases.

LG · Jun 25, 2024
Early learning of the optimal constant solution in neural networks and humans

Jirko Rubruck, Jan P. Bauer, Andrew Saxe et al.

Deep neural networks learn increasingly complex functions over the course of training. Here, we show both empirically and theoretically that learning of the target function is preceded by an early phase in which networks learn the optimal constant solution (OCS) - that is, initial model responses mirror the distribution of target labels, while entirely ignoring information provided in the input. Using a hierarchical category learning task, we derive exact solutions for learning dynamics in deep linear networks trained with bias terms. Even when initialized to zero, this simple architectural feature induces substantial changes in early dynamics. We identify hallmarks of this early OCS phase and illustrate how these signatures are observed in deep linear networks and larger, more complex (and nonlinear) convolutional neural networks solving a hierarchical learning task based on MNIST and CIFAR10. We explain these observations by proving that deep linear networks necessarily learn the OCS during early learning. To further probe the generality of our results, we train human learners over the course of three days on the category learning task. We then identify qualitative signatures of this early OCS phase in terms of the dynamics of true negative (correct-rejection) rates. Surprisingly, we find the same early reliance on the OCS in the behaviour of human learners. Finally, we show that learning of the OCS can emerge even in the absence of bias terms and is equivalently driven by generic correlations in the input data. Overall, our work suggests the OCS as a universal learning principle in supervised, error-corrective learning, and the mechanistic reasons for its prevalence.
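To make the optimal constant solution concrete: under a squared-error loss (an assumption made here for illustration), the best output that ignores the input entirely is the mean of the target vectors. The sketch below uses made-up toy targets and hypothetical helper names; it is not taken from the paper.

```python
def mean_squared_error(prediction, targets):
    """Average squared error of one fixed prediction against every target."""
    total = 0.0
    for target in targets:
        total += sum((p - t) ** 2 for p, t in zip(prediction, target))
    return total / len(targets)


def optimal_constant(targets):
    """Input-independent output minimising squared error: the mean target."""
    n = len(targets)
    return [sum(column) / n for column in zip(*targets)]


# Toy multi-label targets (made up for illustration).
targets = [[1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]]
ocs = optimal_constant(targets)
zeros = [0.0] * 4
print("OCS:", ocs)
print("loss at OCS:", mean_squared_error(ocs, targets))
print("loss at zero output:", mean_squared_error(zeros, targets))
```

A network sitting at this constant already beats an all-zero output without using any input information, which is one way to read the abstract's statement that early responses mirror the distribution of target labels.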

LG · Jun 18, 2024
When Are Bias-Free ReLU Networks Effectively Linear Networks?

Yedi Zhang, Andrew Saxe, Peter E. Latham

We investigate the implications of removing bias in ReLU networks regarding their expressivity and learning dynamics. We first show that two-layer bias-free ReLU networks have limited expressivity: the only odd function two-layer bias-free ReLU networks can express is a linear one. We then show that, under symmetry conditions on the data, these networks have the same learning dynamics as linear networks. This enables us to give analytical time-course solutions to certain two-layer bias-free (leaky) ReLU networks outside the lazy learning regime. While deep bias-free ReLU networks are more expressive than their two-layer counterparts, they still share a number of similarities with deep linear networks. These similarities enable us to leverage insights from linear networks to understand certain ReLU networks. Overall, our results show that some properties previously established for bias-free ReLU networks arise due to equivalence to linear networks.
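The expressivity claim for two-layer networks can be checked numerically. Because relu(z) - relu(-z) = z, the odd part of a bias-free two-layer ReLU network, (f(x) - f(-x)) / 2, is a linear function of x, so any odd function such a network expresses must be linear. The sketch below is an illustration with randomly drawn weights; the function names and sizes are hypothetical.

```python
import random


def relu(z):
    return z if z > 0.0 else 0.0


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def bias_free_relu_net(x, hidden_weights, readout):
    """f(x) = sum_i readout[i] * relu(hidden_weights[i] . x), with no bias terms."""
    return sum(a * relu(dot(w, x)) for a, w in zip(readout, hidden_weights))


random.seed(0)
dim, width = 3, 5
hidden_weights = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(width)]
readout = [random.gauss(0, 1) for _ in range(width)]

# Since relu(z) - relu(-z) = z, the odd part of f is linear in x with this vector.
linear_map = [0.5 * sum(a * w[j] for a, w in zip(readout, hidden_weights))
              for j in range(dim)]

x = [random.gauss(0, 1) for _ in range(dim)]
neg_x = [-v for v in x]
odd_part = 0.5 * (bias_free_relu_net(x, hidden_weights, readout)
                  - bias_free_relu_net(neg_x, hidden_weights, readout))
print(odd_part, dot(linear_map, x))
```

The two printed values agree up to floating-point error for any input, since the identity holds unit by unit.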

LG · Jun 10, 2024
Get rich quick: exact solutions reveal how unbalanced initializations promote rapid feature learning

Daniel Kunin, Allan Raventós, Clémentine Dominé et al.

While the impressive performance of modern neural networks is often attributed to their capacity to efficiently extract task-relevant features from data, the mechanisms underlying this rich feature learning regime remain elusive, with much of our theoretical understanding stemming from the opposing lazy regime. In this work, we derive exact solutions to a minimal model that transitions between lazy and rich learning, precisely elucidating how unbalanced layer-specific initialization variances and learning rates determine the degree of feature learning. Our analysis reveals that they conspire to influence the learning regime through a set of conserved quantities that constrain and modify the geometry of learning trajectories in parameter and function space. We extend our analysis to more complex linear models with multiple neurons, outputs, and layers and to shallow nonlinear networks with piecewise linear activation functions. In linear networks, rapid feature learning only occurs from balanced initializations, where all layers learn at similar speeds, whereas in nonlinear networks, unbalanced initializations that promote faster learning in earlier layers can accelerate rich learning. Through a series of experiments, we provide evidence that this unbalanced rich regime drives feature learning in deep finite-width networks, promotes interpretability of early layers in CNNs, reduces the sample complexity of learning hierarchical data, and decreases the time to grokking in modular arithmetic. Our theory motivates further exploration of unbalanced initializations to enhance efficient feature learning.

ML · Jun 3, 2024
Tilting the Odds at the Lottery: the Interplay of Overparameterisation and Curricula in Neural Networks

Stefano Sarao Mannelli, Yaraslau Ivashynka, Andrew Saxe et al.

A wide range of empirical and theoretical works have shown that overparameterisation can amplify the performance of neural networks. According to the lottery ticket hypothesis, overparameterised networks have an increased chance of containing a sub-network that is well-initialised to solve the task at hand. A more parsimonious approach, inspired by animal learning, consists in guiding the learner towards solving the task by curating the order of the examples, i.e. providing a curriculum. However, this learning strategy seems to be hardly beneficial in deep learning applications. In this work, we undertake an analytical study that connects curriculum learning and overparameterisation. In particular, we investigate their interplay in the online learning setting for a 2-layer network in the XOR-like Gaussian Mixture problem. Our results show that a high degree of overparameterisation -- while simplifying the problem -- can limit the benefit from curricula, providing a theoretical account of the ineffectiveness of curricula in deep learning.

ML · Jul 9, 2021
Continual Learning in the Teacher-Student Setup: Impact of Task Similarity

Sebastian Lee, Sebastian Goldt, Andrew Saxe

Continual learning, the ability to learn many tasks in sequence, is critical for artificial learning systems. Yet standard training methods for deep networks often suffer from catastrophic forgetting, where learning new tasks erases knowledge of earlier tasks. While catastrophic forgetting labels the problem, the theoretical reasons for interference between tasks remain unclear. Here, we attempt to narrow this gap between theory and practice by studying continual learning in the teacher-student setup. We extend previous analytical work on two-layer networks in the teacher-student setup to multiple teachers. Using each teacher to represent a different task, we investigate how the relationship between teachers affects the amount of forgetting and transfer exhibited by the student when the task switches. In line with recent work, we find that when tasks depend on similar features, intermediate task similarity leads to greatest forgetting. However, feature similarity is only one way in which tasks may be related. The teacher-student approach allows us to disentangle task similarity at the level of readouts (hidden-to-output weights) and features (input-to-hidden weights). We find a complex interplay between both types of similarity, initial transfer/forgetting rates, maximum transfer/forgetting, and long-term transfer/forgetting. Together, these results help illuminate the diverse factors contributing to catastrophic forgetting.

LG · Jun 15, 2021
An Analytical Theory of Curriculum Learning in Teacher-Student Networks

Luca Saglietti, Stefano Sarao Mannelli, Andrew Saxe

In humans and animals, curriculum learning (presenting data in a curated order) is critical to rapid learning and effective pedagogy. Yet in machine learning, curricula are not widely used and empirically often yield only moderate benefits. This stark difference in the importance of curriculum raises a fundamental theoretical question: when and why does curriculum learning help? In this work, we analyse a prototypical neural network model of curriculum learning in the high-dimensional limit, employing statistical physics methods. Curricula could in principle change both the learning speed and asymptotic performance of a model. To study the former, we provide an exact description of the online learning setting, confirming the long-standing experimental observation that curricula can modestly speed up learning. To study the latter, we derive performance in a batch learning setting, in which a network trains to convergence in successive phases of learning on dataset slices of varying difficulty. With standard training losses, curriculum does not provide generalisation benefit, in line with empirical observations. However, we show that by connecting different learning phases through simple Gaussian priors, curriculum can yield a large improvement in test performance. Taken together, our reduced analytical descriptions help reconcile apparently conflicting empirical results and trace regimes where curriculum learning yields the largest gains. More broadly, our results suggest that fully exploiting a curriculum may require explicit changes to the loss function at curriculum boundaries.

LG · Jun 9, 2021
Probing transfer learning with a model of synthetic correlated datasets

Federica Gerace, Luca Saglietti, Stefano Sarao Mannelli et al.

Transfer learning can significantly improve the sample efficiency of neural networks, by exploiting the relatedness between a data-scarce target task and a data-abundant source task. Despite years of successful applications, transfer learning practice often relies on ad-hoc solutions, while theoretical understanding of these procedures is still limited. In the present work, we re-think a solvable model of synthetic data as a framework for modeling correlation between data-sets. This setup allows for an analytic characterization of the generalization performance obtained when transferring the learned feature map from the source to the target task. Focusing on the problem of training two-layer networks in a binary classification setting, we show that our model can capture a range of salient features of transfer learning with real data. Moreover, by exploiting parametric control over the correlation between the two data-sets, we systematically investigate under which conditions the transfer of features is beneficial for generalization.

LG · Jul 17, 2018
Are Efficient Deep Representations Learnable?

Maxwell Nye, Andrew Saxe

Many theories of deep learning have shown that a deep network can require dramatically fewer resources to represent a given function compared to a shallow network. But a question remains: can these efficient representations be learned using current deep learning techniques? In this work, we test whether standard deep learning methods can in fact find the efficient representations posited by several theories of deep representation. Specifically, we train deep neural networks to learn two simple functions with known efficient solutions: the parity function and the fast Fourier transform. We find that using gradient-based optimization, a deep network does not learn the parity function, unless initialized very close to a hand-coded exact solution. We also find that a deep linear neural network does not learn the fast Fourier transform, even in the best-case scenario of infinite training data, unless the weights are initialized very close to the exact hand-coded solution. Our results suggest that not every element of the class of compositional functions can be learned efficiently by a deep network, and further restrictions are necessary to understand what functions are both efficiently representable and learnable.
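As an illustration of what an efficient compositional solution to parity can look like, the sketch below computes parity with a balanced tree of pairwise XORs, so the depth grows roughly with the logarithm of the input length. This is one standard construction chosen for illustration; it is an assumption that it resembles the kind of solution meant above, and it is not the hand-coded network from the paper. The name `parity_tree` is hypothetical.

```python
from itertools import product


def xor(a, b):
    return a ^ b


def parity_tree(bits):
    """Parity via a balanced tree of pairwise XORs (depth about log2(len(bits)))."""
    layer = list(bits)
    depth = 0
    while len(layer) > 1:
        nxt = [xor(layer[i], layer[i + 1]) for i in range(0, len(layer) - 1, 2)]
        if len(layer) % 2:  # carry an unpaired bit up to the next layer
            nxt.append(layer[-1])
        layer = nxt
        depth += 1
    return layer[0], depth


for bits in product([0, 1], repeat=3):
    print(bits, parity_tree(bits)[0])
```

Each layer reuses the same two-input operation, which is the sense in which the representation is compact. The abstract's point is that gradient-based training did not find such a solution unless initialized very close to it.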

NE · Oct 31, 2016
Tensor Switching Networks

Chuan-Yung Tsai, Andrew Saxe, David Cox

We present a novel neural network algorithm, the Tensor Switching (TS) network, which generalizes the Rectified Linear Unit (ReLU) nonlinearity to tensor-valued hidden units. The TS network copies its entire input vector to different locations in an expanded representation, with the location determined by its hidden unit activity. In this way, even a simple linear readout from the TS representation can implement a highly expressive deep-network-like function. The TS network hence avoids the vanishing gradient problem by construction, at the cost of larger representation size. We develop several methods to train the TS network, including equivalent kernels for infinitely wide and deep TS networks, a one-pass linear learning algorithm, and two backpropagation-inspired representation learning algorithms. Our experimental results demonstrate that the TS network is indeed more expressive and consistently learns faster than standard ReLU networks.
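The copying mechanism described above can be sketched for a single layer. The assumption here, based only on the abstract's wording, is that each hidden unit owns one slot in the expanded representation and that the slot holds a copy of the input when the unit is active and zeros otherwise. The function names and the toy numbers are hypothetical, and this is not the authors' code.

```python
def ts_representation(x, weights):
    """Single-layer sketch: row i is a copy of x if unit i is active, else zeros."""
    rows = []
    for w in weights:
        active = sum(wj * xj for wj, xj in zip(w, x)) > 0.0
        rows.append(list(x) if active else [0.0] * len(x))
    return rows


def relu_layer(x, weights):
    return [max(0.0, sum(wj * xj for wj, xj in zip(w, x))) for w in weights]


x = [1.0, -2.0, 0.5]
weights = [[0.5, 0.1, -1.0], [-1.0, 0.3, 0.2], [0.2, -0.4, 0.6]]
z = ts_representation(x, weights)

# A linear readout that reuses the gating weights recovers the ReLU activations.
readout = [sum(wj * zj for wj, zj in zip(w, row)) for w, row in zip(weights, z)]
print(z)
print(readout, relu_layer(x, weights))
```

Under this reading, a standard ReLU layer is the special case in which the readout is tied to the gating weights. A freely trained linear readout over the same representation can express more, at the cost of a representation that is larger by a factor of the input dimension.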