LG · Sep 26, 2024 · Code
DMC-VB: A Benchmark for Representation Learning for Control with Visual Distractors
Joseph Ortiz, Antoine Dedieu, Wolfgang Lehrach et al. · deepmind
Learning from previously collected data via behavioral cloning or offline reinforcement learning (RL) is a powerful recipe for scaling generalist agents by avoiding the need for expensive online learning. Despite strong generalization in some respects, agents are often remarkably brittle to minor visual variations in control-irrelevant factors such as the background or camera viewpoint. In this paper, we present the DeepMind Control Visual Benchmark (DMC-VB), a dataset collected in the DeepMind Control Suite to evaluate the robustness of offline RL agents for solving continuous control tasks from visual input in the presence of visual distractors. In contrast to prior works, our dataset (a) combines locomotion and navigation tasks of varying difficulties, (b) includes static and dynamic visual variations, (c) considers data generated by policies with different skill levels, (d) systematically returns pairs of state and pixel observations, (e) is an order of magnitude larger, and (f) includes tasks with hidden goals. Accompanying our dataset, we propose three benchmarks to evaluate representation learning methods for pretraining, and carry out experiments on several recently proposed methods. First, we find that pretrained representations do not help policy learning on DMC-VB, and we highlight a large representation gap between policies learned on pixel observations and on states. Second, we demonstrate that when expert data is limited, policy learning can benefit from representations pretrained on (a) suboptimal data, and (b) tasks with stochastic hidden goals. Our dataset and benchmark code to train and evaluate agents are available at: https://github.com/google-deepmind/dmc_vision_benchmark.
LG · Jan 31, 2023
Learning noisy-OR Bayesian Networks with Max-Product Belief Propagation
Antoine Dedieu, Guangyao Zhou, Dileep George et al. · deepmind
Noisy-OR Bayesian Networks (BNs) are a family of probabilistic graphical models which express rich statistical dependencies in binary data. Variational inference (VI) has been the main method proposed to learn noisy-OR BNs with complex latent structures (Jaakkola & Jordan, 1999; Ji et al., 2020; Buhai et al., 2020). However, the proposed VI approaches either (a) use a recognition network with standard amortized inference that cannot induce "explaining-away"; or (b) assume a simple mean-field (MF) posterior which is vulnerable to bad local optima. Existing MF VI methods also update the MF parameters sequentially, which makes them inherently slow. In this paper, we propose parallel max-product as an alternative algorithm for learning noisy-OR BNs with complex latent structures and we derive a fast stochastic training scheme that scales to large datasets. We evaluate both approaches on several benchmarks where VI is the state-of-the-art and show that our method (a) achieves better test performance than Ji et al. (2020) for learning noisy-OR BNs with hierarchical latent structures on large sparse real datasets; (b) recovers a higher number of ground truth parameters than Buhai et al. (2020) from cluttered synthetic scenes; and (c) solves the 2D blind deconvolution problem from Lazaro-Gredilla et al. (2021) and variants (including binary matrix factorization), while VI catastrophically fails and is up to two orders of magnitude slower.
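To make the noisy-OR conditional and the "explaining-away" effect concrete, here is a minimal sketch in plain Python. It is not the paper's learning algorithm: the function names, the two-parent toy model, and the parameter values are all illustrative assumptions, and the posterior is computed by brute-force enumeration, which is only viable for tiny models.

```python
from itertools import product


def noisy_or_prob(parents, fail_probs, leak=0.0):
    """P(child = 1 | parents) under a noisy-OR conditional.

    parents:    tuple of 0/1 parent states.
    fail_probs: fail_probs[i] is the probability that an active parent i
                fails to switch the child on.
    leak:       probability the child is on with no active parent.
    """
    p_off = 1.0 - leak
    for z, q in zip(parents, fail_probs):
        if z:
            p_off *= q
    return 1.0 - p_off


def posterior_given_child_on(prior, fail_probs, leak=0.0):
    """Exact P(parents | child = 1) by enumerating all parent states."""
    joint = {}
    for parents in product((0, 1), repeat=len(prior)):
        p = noisy_or_prob(parents, fail_probs, leak)
        for z, pi in zip(parents, prior):
            p *= pi if z else 1.0 - pi
        joint[parents] = p
    total = sum(joint.values())
    return {k: v / total for k, v in joint.items()}


# Toy model (assumed values): two rare causes of one observed effect.
post = posterior_given_child_on(prior=(0.1, 0.1), fail_probs=(0.1, 0.1), leak=0.01)
p_z1 = post[(1, 0)] + post[(1, 1)]                       # P(z1=1 | x=1)
p_z1_given_z2 = post[(1, 1)] / (post[(0, 1)] + post[(1, 1)])  # P(z1=1 | x=1, z2=1)
print(p_z1, p_z1_given_z2)
```

Learning that the second cause is active lowers the posterior probability of the first cause: this is the explaining-away behaviour that, according to the abstract, standard amortized recognition networks cannot induce.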
CL · Jun 16, 2023
Schema-learning and rebinding as mechanisms of in-context learning and emergence
Sivaramakrishnan Swaminathan, Antoine Dedieu, Rajkumar Vasudeva Raju et al.
In-context learning (ICL) is one of the most powerful and most unexpected capabilities to emerge in recent transformer-based large language models (LLMs). Yet the mechanisms that underlie it are poorly understood. In this paper, we demonstrate that comparable ICL capabilities can be acquired by an alternative sequence prediction learning method using clone-structured causal graphs (CSCGs). Moreover, a key property of CSCGs is that, unlike transformer-based LLMs, they are interpretable, which considerably simplifies the task of explaining how ICL works. Specifically, we show that it uses a combination of (a) learning template (schema) circuits for pattern completion, (b) retrieving relevant templates in a context-sensitive manner, and (c) rebinding of novel tokens to appropriate slots in the templates. We go on to marshal evidence for the hypothesis that similar mechanisms underlie ICL in LLMs. For example, we find that, with CSCGs as with LLMs, different capabilities emerge at different levels of overparameterization, suggesting that overparameterization helps in learning more complex template (schema) circuits. By showing how ICL can be achieved with small models and datasets, we open up a path to novel architectures, and take a vital step towards a more general understanding of the mechanics behind this important capability.
LG · Feb 8, 2022 · Code
PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX
Guangyao Zhou, Antoine Dedieu, Nishanth Kumar et al.
PGMax is an open-source Python package for (a) easily specifying discrete Probabilistic Graphical Models (PGMs) as factor graphs; and (b) automatically running efficient and scalable loopy belief propagation (LBP) in JAX. PGMax supports general factor graphs with tractable factors, and leverages modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with up to three orders-of-magnitude inference time speedups. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up new research possibilities. Our source code, examples and documentation are available at https://github.com/deepmind/PGMax.
LG · Feb 3, 2025
Improving Transformer World Models for Data-Efficient RL
Antoine Dedieu, Joseph Ortiz, Xinghua Lou et al.
We present three improvements to the standard model-based RL paradigm based on transformers: (a) "Dyna with warmup", which trains the policy on real and imaginary data, but only starts using imaginary data after the world model has been sufficiently trained; (b) "nearest neighbor tokenizer" for image patches, which improves upon previous tokenization schemes, which are needed when using a transformer world model (TWM), by ensuring the code words are static after creation, thus providing a constant target for TWM learning; and (c) "block teacher forcing", which allows the TWM to reason jointly about the future tokens of the next timestep, instead of generating them sequentially. We then show that our method significantly improves upon prior methods in various environments. We mostly focus on the challenging Craftax-classic benchmark, where our method achieves a reward of 69.66% after only 1M environment steps, significantly outperforming DreamerV3, which achieves 53.2%, and exceeding human performance of 65.0% for the first time. We also show preliminary results on Craftax-full, MinAtar, and three different two-player games, to illustrate the generality of the approach.
LG · Jan 11, 2024
Learning Cognitive Maps from Transformer Representations for Efficient Planning in Partially Observed Environments
Antoine Dedieu, Wolfgang Lehrach, Guangyao Zhou et al. · deepmind
Despite their stellar performance on a wide range of tasks, including in-context tasks only revealed during inference, vanilla transformers and variants trained for next-token prediction (a) do not learn an explicit world model of their environment which can be flexibly queried and (b) cannot be used for planning or navigation. In this paper, we consider partially observed environments (POEs), where an agent receives perceptually aliased observations as it navigates, which makes path planning hard. We introduce a transformer with (multiple) discrete bottleneck(s), TDB, whose latent codes learn a compressed representation of the history of observations and actions. After training a TDB to predict the future observation(s) given the history, we extract interpretable cognitive maps of the environment from its active bottleneck(s) indices. These maps are then paired with an external solver to solve (constrained) path planning problems. First, we show that a TDB trained on POEs (a) retains the near-perfect predictive performance of a vanilla transformer or an LSTM while (b) solving shortest path problems exponentially faster. Second, a TDB extracts interpretable representations from text datasets, while reaching higher in-context accuracy than vanilla sequence models. Finally, in new POEs, a TDB (a) reaches near-perfect in-context accuracy, (b) learns accurate in-context cognitive maps, and (c) solves in-context path planning problems.
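The step of pairing an extracted map with an external solver can be sketched as follows. This is an illustration under stated assumptions, not the paper's pipeline: the cognitive map is assumed to be available as a dictionary from (node, action) to next node, where nodes stand in for bottleneck indices, and a plain breadth-first search plays the role of the external solver. The toy map and all names are hypothetical.

```python
from collections import deque


def shortest_action_path(transitions, start, goal):
    """Breadth-first search over a map given as {(node, action): next_node}.

    Returns the shortest list of actions from start to goal, or None if the
    goal cannot be reached.
    """
    graph = {}
    for (node, action), nxt in transitions.items():
        graph.setdefault(node, []).append((action, nxt))

    parent = {start: None}  # node -> (previous node, action taken)
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            actions = []
            while parent[node] is not None:
                node, action = parent[node]
                actions.append(action)
            return actions[::-1]
        for action, nxt in graph.get(node, []):
            if nxt not in parent:
                parent[nxt] = (node, action)
                queue.append(nxt)
    return None


# Hypothetical 2x2 grid map: nodes 0..3, with one one-way edge back up.
toy_map = {
    (0, "right"): 1,
    (0, "down"): 2,
    (1, "down"): 3,
    (2, "right"): 3,
    (3, "up"): 1,
}
print(shortest_action_path(toy_map, 0, 3))
```

Once the map is an explicit graph, planning costs one graph search rather than repeated rollouts of a sequence model, which is the kind of saving the abstract attributes to the extracted maps.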
AI · Oct 6, 2025
Code World Models for General Game Playing
Wolfgang Lehrach, Daniel Hennes, Miguel Lazaro-Gredilla et al.
The reasoning abilities of Large Language Models (LLMs) are increasingly being applied to classical board and card games, but the dominant approach -- involving prompting for direct move generation -- has significant drawbacks. It relies on the model's implicit fragile pattern-matching capabilities, leading to frequent illegal moves and strategically shallow play. Here we introduce an alternative approach: We use the LLM to translate natural language rules and game trajectories into a formal, executable world model represented as Python code. This generated model -- comprising functions for state transition, legal move enumeration, and termination checks -- serves as a verifiable simulation engine for high-performance planning algorithms like Monte Carlo tree search (MCTS). In addition, we prompt the LLM to generate heuristic value functions (to make MCTS more efficient), and inference functions (to estimate hidden states in imperfect information games). Our method offers three distinct advantages compared to directly using the LLM as a policy: (1) Verifiability: The generated code world model (CWM) serves as a formal specification of the game's rules, allowing planners to algorithmically enumerate valid actions and avoid illegal moves, contingent on the correctness of the synthesized model; (2) Strategic Depth: We combine LLM semantic understanding with the deep search power of classical planners; and (3) Generalization: We direct the LLM to focus on the meta-task of data-to-code translation, enabling it to adapt to new games more easily. We evaluate our agent on 10 different games, of which 4 are novel and created for this paper. Five of the games are fully observed (perfect information), and five are partially observed (imperfect information). We find that our method outperforms or matches Gemini 2.5 Pro in 9 out of the 10 considered games.
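As a rough illustration of what an executable world model with state-transition, legal-move and termination functions might look like, here is a hand-written sketch for single-pile Nim. Everything in it is an assumption made for illustration: the game, the function names and the state encoding are hypothetical, the code was not generated by an LLM, and an exhaustive negamax search stands in for MCTS because the game is tiny.

```python
from functools import lru_cache

# Hypothetical world model for single-pile Nim: players alternately remove
# 1 to 3 stones, and whoever takes the last stone wins.
# A state is (stones_left, player_to_move).


def legal_actions(state):
    pile, _player = state
    return [k for k in (1, 2, 3) if k <= pile]


def apply_action(state, action):
    pile, player = state
    if action not in legal_actions(state):
        raise ValueError(f"illegal action {action} in state {state}")
    return (pile - action, 1 - player)


def is_terminal(state):
    return state[0] == 0


@lru_cache(maxsize=None)
def value(state):
    """+1 if the player to move wins under optimal play, otherwise -1."""
    if is_terminal(state):
        return -1  # the previous player took the last stone and won
    return max(-value(apply_action(state, a)) for a in legal_actions(state))


def best_action(state):
    """Planner: choose the legal action with the best searched value."""
    return max(legal_actions(state), key=lambda a: -value(apply_action(state, a)))


print(best_action((5, 0)))
```

Because the planner only ever selects from `legal_actions`, it cannot propose an illegal move as long as the world model itself is correct, which mirrors the verifiability argument in the abstract.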
LG · Dec 6, 2021
Graphical Models with Attention for Context-Specific Independence and an Application to Perceptual Grouping
Guangyao Zhou, Wolfgang Lehrach, Antoine Dedieu et al.
Discrete undirected graphical models, also known as Markov Random Fields (MRFs), can flexibly encode probabilistic interactions of multiple variables, and have enjoyed successful applications to a wide range of problems. However, a well-known yet little studied limitation of discrete MRFs is that they cannot capture context-specific independence (CSI). Existing methods require carefully developed theories and purpose-built inference methods, which limit their applications to only small-scale problems. In this paper, we propose the Markov Attention Model (MAM), a family of discrete MRFs that incorporates an attention mechanism. The attention mechanism allows variables to dynamically attend to some other variables while ignoring the rest, and enables capturing of CSIs in MRFs. A MAM is formulated as an MRF, allowing it to benefit from the rich set of existing MRF inference methods and scale to large models and datasets. To demonstrate MAM's capabilities to capture CSIs at scale, we apply MAMs to capture an important type of CSI that is present in a symbolic approach to recurrent computations in perceptual grouping. Experiments on two recently proposed synthetic perceptual grouping tasks and on realistic images demonstrate the advantages of MAMs in sample-efficiency, interpretability and generalizability when compared with strong recurrent neural network baselines, and validate MAM's capabilities to efficiently capture CSIs at scale.
ML · Nov 3, 2021
Perturb-and-max-product: Sampling and learning in discrete energy-based models
Miguel Lazaro-Gredilla, Antoine Dedieu, Dileep George
Perturb-and-MAP offers an elegant approach to approximately sample from an energy-based model (EBM) by computing the maximum-a-posteriori (MAP) configuration of a perturbed version of the model. Sampling in turn enables learning. However, this line of research has been hindered by the general intractability of the MAP computation. Very few works venture outside tractable models, and when they do, they use linear programming approaches, which, as we will show, have several limitations. In this work we present perturb-and-max-product (PMP), a parallel and scalable mechanism for sampling and learning in discrete EBMs. Models can be arbitrary as long as they are built using tractable factors. We show that (a) for Ising models, PMP is orders of magnitude faster than Gibbs and Gibbs-with-Gradients (GWG) at learning and generating samples of similar or better quality; (b) PMP is able to learn and sample from RBMs; (c) in a large, entangled graphical model in which Gibbs and GWG fail to mix, PMP succeeds.
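The perturb-and-MAP idea can be sketched on a model small enough to enumerate. The example below is not PMP itself: it perturbs every joint configuration with independent Gumbel noise (the Gumbel-max trick, which gives exact samples but costs exponential time) and finds the MAP by brute force, whereas the abstract describes using max-product for the MAP step. The three-spin chain and its parameters are assumed for illustration.

```python
import math
import random
from itertools import product


def gumbel(rng):
    """Draw a standard Gumbel(0, 1) sample by inverting the CDF."""
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def log_potential(x, coupling=0.8, field=0.3):
    """Unnormalized log-probability of a spin chain with spins in {-1, +1}."""
    pair = sum(a * b for a, b in zip(x, x[1:]))
    return coupling * pair + field * sum(x)


def perturb_and_map_sample(configs, log_pot, rng):
    """Add Gumbel noise to each configuration's log-potential; return the MAP."""
    return max(configs, key=lambda x: log_pot(x) + gumbel(rng))


def exact_probs(configs, log_pot):
    """Exact Boltzmann probabilities by enumeration, for comparison."""
    weights = {x: math.exp(log_pot(x)) for x in configs}
    z = sum(weights.values())
    return {x: w / z for x, w in weights.items()}


rng = random.Random(0)
configs = list(product((-1, 1), repeat=3))
samples = [perturb_and_map_sample(configs, log_potential, rng) for _ in range(5)]
print(samples)
```

With full-order perturbations like these, the empirical sample frequencies converge to the exact probabilities. Scalable variants perturb only low-order terms and rely on an approximate MAP solver, which is why the samples are then approximate.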
ML · Dec 3, 2020
Sample-Efficient L0-L2 Constrained Structure Learning of Sparse Ising Models
Antoine Dedieu, Miguel Lázaro-Gredilla, Dileep George
We consider the problem of learning the underlying graph of a sparse Ising model with $p$ nodes from $n$ i.i.d. samples. The most recent and best performing approaches combine an empirical loss (the logistic regression loss or the interaction screening loss) with a regularizer (an L1 penalty or an L1 constraint). This results in a convex problem that can be solved separately for each node of the graph. In this work, we leverage the cardinality constraint L0 norm, which is known to properly induce sparsity, and further combine it with an L2 norm to better model the non-zero coefficients. We show that our proposed estimators achieve an improved sample complexity, both (a) theoretically, by reaching new state-of-the-art upper bounds for recovery guarantees, and (b) empirically, by showing sharper phase transitions between poor and full recovery for graph topologies studied in the literature, when compared to L1-based state-of-the-art methods.
ML · Jun 11, 2020
Query Training: Learning a Worse Model to Infer Better Marginals in Undirected Graphical Models with Hidden Variables
Miguel Lázaro-Gredilla, Wolfgang Lehrach, Nishad Gothoskar et al.
Probabilistic graphical models (PGMs) provide a compact representation of knowledge that can be queried in a flexible way: after learning the parameters of a graphical model once, new probabilistic queries can be answered at test time without retraining. However, when using undirected PGMs with hidden variables, two sources of error typically compound in all but the simplest models: (a) learning error (both computing the partition function and integrating out the hidden variables are intractable); and (b) prediction error (exact inference is also intractable). Here we introduce query training (QT), a mechanism to learn a PGM that is optimized for the approximate inference algorithm that will be paired with it. The resulting PGM is a worse model of the data (as measured by the likelihood), but it is tuned to produce better marginals for a given inference algorithm. Unlike prior works, our approach preserves the querying flexibility of the original PGM: at test time, we can estimate the marginal of any variable given any partial evidence. We demonstrate experimentally that QT can be used to learn a challenging 8-connected grid Markov random field with hidden variables and that it consistently outperforms the state-of-the-art AdVIL when tested on three undirected models across multiple datasets.
ML · Jan 17, 2020
Learning Sparse Classifiers: Continuous and Mixed Integer Optimization Perspectives
Antoine Dedieu, Hussein Hazimeh, Rahul Mazumder
We consider a discrete optimization formulation for learning sparse classifiers, where the outcome depends upon a linear combination of a small subset of features. Recent work has shown that mixed integer programming (MIP) can be used to solve (to optimality) $\ell_0$-regularized regression problems at scales much larger than what was conventionally considered possible. Despite their usefulness, MIP-based global optimization approaches are significantly slower compared to the relatively mature algorithms for $\ell_1$-regularization and heuristics for nonconvex regularized problems. We aim to bridge this gap in computation times by developing new MIP-based algorithms for $\ell_0$-regularized classification. We propose two classes of scalable algorithms: an exact algorithm that can handle $p\approx 50,000$ features in a few minutes, and approximate algorithms that can address instances with $p\approx 10^6$ in times comparable to the fast $\ell_1$-based algorithms. Our exact algorithm is based on the novel idea of integrality generation, which solves the original problem (with $p$ binary variables) via a sequence of mixed integer programs that involve a small number of binary variables. Our approximate algorithms are based on coordinate descent and local combinatorial search. In addition, we present new estimation error bounds for a class of $\ell_0$-regularized estimators. Experiments on real and synthetic data demonstrate that our approach leads to models with considerably improved statistical performance (especially, variable selection) when compared to competing methods.
ML · Dec 21, 2019
An error bound for Lasso and Group Lasso in high dimensions
Antoine Dedieu
We leverage recent advances in high-dimensional statistics to derive new L2 estimation upper bounds for Lasso and Group Lasso in high dimensions. For Lasso, our bounds scale as $(k^*/n) \log(p/k^*)$, where $n\times p$ is the size of the design matrix and $k^*$ the dimension of the ground truth $\boldsymbol{\beta}^*$, and match the optimal minimax rate. For Group Lasso, our bounds scale as $(s^*/n) \log\left( G / s^* \right) + m^* / n$, where $G$ is the total number of groups and $m^*$ the number of coefficients in the $s^*$ groups which contain $\boldsymbol{\beta}^*$, and improve over existing results. We additionally show that when the signal is strongly group-sparse, Group Lasso is superior to Lasso.
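The two rates quoted above can be compared numerically. The sketch below evaluates both expressions up to their unspecified constants; the problem sizes are made-up values for a case where the true support fills a few groups entirely, so that $k^* = m^*$. It does not reproduce any result from the paper.

```python
import math


def lasso_rate(k, n, p):
    """(k*/n) log(p/k*): the Lasso rate from the abstract, constants omitted."""
    return (k / n) * math.log(p / k)


def group_lasso_rate(s, n, n_groups, m):
    """(s*/n) log(G/s*) + m*/n: the Group Lasso rate, constants omitted."""
    return (s / n) * math.log(n_groups / s) + m / n


# Assumed setting: p = 10,000 features in G = 1,000 groups of size 10.
# The signal occupies s* = 5 whole groups, so k* = m* = 50 coefficients.
n, p, n_groups = 1000, 10_000, 1000
s_star, m_star = 5, 50
k_star = m_star

print("Lasso rate:      ", lasso_rate(k_star, n, p))
print("Group Lasso rate:", group_lasso_rate(s_star, n, n_groups, m_star))
```

In this setting the logarithmic factor multiplies the number of active groups rather than the number of active coefficients, which is why the Group Lasso expression comes out smaller.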
ML · Oct 20, 2019
Improved error rates for sparse (group) learning with Lipschitz loss functions
Antoine Dedieu
We study a family of sparse estimators defined as minimizers of some empirical Lipschitz loss function -- which includes the hinge loss, the logistic loss and the quantile regression loss -- with a convex, sparse or group-sparse regularization. In particular, we consider the L1 norm on the coefficients, its sorted Slope version, and the Group L1-L2 extension. We propose a new theoretical framework that uses common assumptions in the literature to simultaneously derive new high-dimensional L2 estimation upper bounds for all three regularization schemes. For L1 and Slope regularizations, our bounds scale as $(k^*/n) \log(p/k^*)$ -- $n\times p$ is the size of the design matrix and $k^*$ the dimension of the theoretical loss minimizer $\boldsymbol{\beta}^*$ -- and match the optimal minimax rate achieved for the least-squares case. For Group L1-L2 regularization, our bounds scale as $(s^*/n) \log\left( G / s^* \right) + m^* / n$ -- $G$ is the total number of groups and $m^*$ the number of coefficients in the $s^*$ groups which contain $\boldsymbol{\beta}^*$ -- and improve over the least-squares case. We show that, when the signal is strongly group-sparse, Group L1-L2 is superior to L1 and Slope. In addition, we adapt our approach to the sub-Gaussian linear regression framework and reach the optimal minimax rate for Lasso, and an improved rate for Group-Lasso. Finally, we release an accelerated proximal algorithm that computes the nine main convex estimators of interest when the number of variables is of the order of 100,000s.
ML · May 1, 2019
Learning higher-order sequential structure with cloned HMMs
Antoine Dedieu, Nishad Gothoskar, Scott Swingle et al.
Variable order sequence modeling is an important problem in artificial and natural intelligence. While overcomplete Hidden Markov Models (HMMs), in theory, have the capacity to represent long-term temporal structure, they often fail to learn and converge to local minima. We show that by constraining HMMs with a simple sparsity structure inspired by biology, we can make them learn variable order sequences efficiently. We call this model cloned HMM (CHMM) because the sparsity structure enforces that many hidden states map deterministically to the same emission state. CHMMs with over 1 billion parameters can be efficiently trained on GPUs without being severely affected by the credit diffusion problem of standard HMMs. Unlike n-grams and sequence memoizers, CHMMs can model temporal dependencies at arbitrarily long distances and recognize contexts with 'holes' in them. Compared to Recurrent Neural Networks and their Long Short-Term Memory extensions (LSTMs), CHMMs are generative models that can natively deal with uncertainty. Moreover, CHMMs return a higher-order graph that represents the temporal structure of the data which can be useful for community detection, and for building hierarchical models. Our experiments show that CHMMs can beat n-grams, sequence memoizers, and LSTMs on character-level language modeling tasks. CHMMs can be a viable alternative to these methods in some tasks that require variable order sequence modeling and the handling of uncertainty.
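The deterministic emission structure has a direct computational consequence, sketched below: at each step the forward pass only needs to touch the clones of the observed symbol. This is a minimal illustration rather than the authors' implementation. It assumes a fixed number of clones per symbol laid out contiguously, dense transition lists, and hypothetical function names.

```python
import math


def clones_of(symbol, n_clones):
    """Hidden states that deterministically emit `symbol` (assumed layout)."""
    start = symbol * n_clones
    return range(start, start + n_clones)


def chmm_log_likelihood(seq, trans, init, n_clones):
    """Scaled forward pass restricted to the clones of each observed symbol.

    trans[i][j] = P(h_{t+1} = j | h_t = i); init[i] = P(h_1 = i).
    Hidden states outside the observed symbol's clones have zero emission
    probability, so they are skipped entirely.
    """
    alpha = {h: init[h] for h in clones_of(seq[0], n_clones)}
    log_lik = 0.0
    for symbol in seq[1:]:
        norm = sum(alpha.values())
        log_lik += math.log(norm)
        alpha = {
            j: sum(alpha[i] / norm * trans[i][j] for i in alpha)
            for j in clones_of(symbol, n_clones)
        }
    return log_lik + math.log(sum(alpha.values()))


# Toy check (assumed values): 2 symbols x 2 clones, uniform parameters,
# so every symbol has probability 1/2 at every step.
n_hidden = 4
uniform_trans = [[1.0 / n_hidden] * n_hidden for _ in range(n_hidden)]
uniform_init = [1.0 / n_hidden] * n_hidden
print(chmm_log_likelihood([0, 1, 1], uniform_trans, uniform_init, n_clones=2))
```

Each update costs on the order of the squared number of clones per symbol rather than the squared total number of hidden states, which is one reason a sparsity structure of this kind can keep very large models trainable.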
ML · Jan 6, 2019
Solving L1-regularized SVMs and related linear programs: Revisiting the effectiveness of Column and Constraint Generation
Antoine Dedieu, Rahul Mazumder, Haoyue Wang
The linear Support Vector Machine (SVM) is a classic classification technique in machine learning. Motivated by applications in modern high dimensional statistics, we consider penalized SVM problems involving the minimization of a hinge-loss function with a convex sparsity-inducing regularizer such as: the L1-norm on the coefficients, its grouped generalization and the sorted L1-penalty (aka Slope). Each problem can be expressed as a Linear Program (LP) and is computationally challenging when the number of features and/or samples is large -- the current state of algorithms for these problems is rather nascent when compared to the usual L2-regularized linear SVM. To this end, we propose new computational algorithms for these LPs by bringing together techniques from (a) classical column (and constraint) generation methods and (b) first order methods for non-smooth convex optimization -- techniques that are rarely used together for solving large scale LPs. These components have their respective strengths; and while they are found to be useful as separate entities, they have not been used together in the context of solving large scale LPs such as the ones studied herein. Our approach complements the strengths of (a) and (b) -- leading to a scheme that seems to significantly outperform commercial solvers as well as specialized implementations for these problems. We present numerical results on a series of real and synthetic datasets demonstrating the surprising effectiveness of classic column/constraint generation methods in the context of challenging LP-based machine learning tasks.
ML · Mar 4, 2018
Hierarchical Modeling and Shrinkage for User Session Length Prediction in Media Streaming
Antoine Dedieu, Rahul Mazumder, Zhen Zhu et al.
An important metric of users' satisfaction and engagement within on-line streaming services is the user session length, i.e. the amount of time they spend on a service continuously without interruption. Being able to predict this value directly benefits the recommendation and ad pacing contexts in music and video streaming services. Recent research has shown that predicting the exact amount of time spent is highly nontrivial due to many external factors for which a user can end a session, and the lack of predictive covariates. Most of the other related literature on duration based user engagement has focused on dwell time for websites, for search and display ads, mainly for post-click satisfaction prediction or ad ranking. In this work we present a novel framework inspired by hierarchical Bayesian modeling to predict, at the moment of login, the amount of time a user will spend in the streaming service. The time spent by a user on a platform depends upon user-specific latent variables which are learned via hierarchical shrinkage. Our framework enjoys theoretical guarantees and naturally incorporates flexible parametric/nonparametric models on the covariates, including models robust to outliers. Our proposal is found to outperform state-of-the-art estimators in terms of efficiency and predictive performance on real world public and private datasets.
ME · Aug 10, 2017
Subset Selection with Shrinkage: Sparse Linear Modeling when the SNR is low
Rahul Mazumder, Peter Radchenko, Antoine Dedieu
We study a seemingly unexpected and relatively less understood overfitting aspect of a fundamental tool in sparse linear modeling - best subset selection, which minimizes the residual sum of squares subject to a constraint on the number of nonzero coefficients. While the best subset selection procedure is often perceived as the "gold standard" in sparse learning when the signal to noise ratio (SNR) is high, its predictive performance deteriorates when the SNR is low. In particular, it is outperformed by continuous shrinkage methods, such as ridge regression and the Lasso. We investigate the behavior of best subset selection in the high-noise regimes and propose an alternative approach based on a regularized version of the least-squares criterion. Our proposed estimators (a) mitigate, to a large extent, the poor predictive performance of best subset selection in the high-noise regimes; and (b) perform favorably, while generally delivering substantially sparser models, relative to the best predictive models available via ridge regression and the Lasso. We conduct an extensive theoretical analysis of the predictive properties of the proposed approach and provide justification for its superior predictive performance relative to best subset selection when the noise-level is high. Our estimators can be expressed as solutions to mixed integer second order conic optimization problems and, hence, are amenable to modern computational tools from mathematical optimization.