LG · Dec 20, 2022 · Code
Model Ratatouille: Recycling Diverse Models for Out-of-Distribution Generalization
Alexandre Ramé, Kartik Ahuja, Jianyu Zhang et al.
Foundation models are redefining how AI systems are built. Practitioners now follow a standard procedure to build their machine learning solutions: from a pre-trained foundation model, they fine-tune the weights on the target task of interest. So, the Internet is swarmed by a handful of foundation models fine-tuned on many diverse tasks: these individual fine-tunings exist in isolation without benefiting from each other. In our opinion, this is a missed opportunity, as these specialized models contain rich and diverse features. In this paper, we thus propose model ratatouille, a new strategy to recycle the multiple fine-tunings of the same foundation model on diverse auxiliary tasks. Specifically, we repurpose these auxiliary weights as initializations for multiple parallel fine-tunings on the target task; then, we average all fine-tuned weights to obtain the final model. This recycling strategy aims at maximizing the diversity in weights by leveraging the diversity in auxiliary tasks. Empirically, it improves the state of the art on the reference DomainBed benchmark for out-of-distribution generalization. Looking forward, this work contributes to the emerging paradigm of updatable machine learning where, akin to open-source software development, the community collaborates to reliably update machine learning models. Our code is released: https://github.com/facebookresearch/ModelRatatouille.
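The final averaging step described in the abstract above can be sketched as follows. This is a minimal illustration, not the authors' released code: it assumes uniform averaging, represents each model as a plain dictionary of flat parameter lists, and the names `average_weights` and `finetuned` are hypothetical.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def average_weights(models: List[Weights]) -> Weights:
    """Uniformly average each named parameter across models (assumed scheme)."""
    if not models:
        raise ValueError("need at least one model")
    averaged: Weights = {}
    for name in models[0]:
        # zip(*...) walks the same coordinate of this parameter in every model.
        columns = zip(*(m[name] for m in models))
        averaged[name] = [sum(col) / len(models) for col in columns]
    return averaged


# Hypothetical stand-ins for three fine-tunings on the target task, each
# started from a different auxiliary-task initialization.
finetuned = [
    {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]},
    {"layer.weight": [3.0, 4.0], "layer.bias": [3.0]},
    {"layer.weight": [2.0, 0.0], "layer.bias": [0.0]},
]
final_model = average_weights(finetuned)
```

Averaging only makes sense here because all models share the same architecture and parameter names, which holds when they descend from the same foundation model.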
LG · Mar 18, 2022 · Code
WOODS: Benchmarks for Out-of-Distribution Generalization in Time Series
Jean-Christophe Gagnon-Audet, Kartik Ahuja, Mohammad-Javad Darvishi-Bayazi et al.
Machine learning models often fail to generalize well under distributional shifts. Understanding and overcoming these failures have led to a research field of Out-of-Distribution (OOD) generalization. Despite being extensively studied for static computer vision tasks, OOD generalization has been underexplored for time series tasks. To shine light on this gap, we present WOODS: eight challenging open-source time series benchmarks covering a diverse range of data modalities, such as videos, brain recordings, and sensor signals. We revise the existing OOD generalization algorithms for time series tasks and evaluate them using our systematic framework. Our experiments show a large room for improvement for empirical risk minimization and OOD generalization algorithms on our datasets, thus underscoring the new challenges posed by time series tasks. Code and documentation are available at https://woods-benchmarks.github.io .
LG · Jun 2, 2022
Weakly Supervised Representation Learning with Sparse Perturbations
Kartik Ahuja, Jason Hartford, Yoshua Bengio
The theory of representation learning aims to build methods that provably invert the data generating process with minimal domain knowledge or any source of supervision. Most prior approaches require strong distributional assumptions on the latent variables and weak supervision (auxiliary information such as timestamps) to provide provable identification guarantees. In this work, we show that if one has weak supervision from observations generated by sparse perturbations of the latent variables--e.g. images in a reinforcement learning environment where actions move individual sprites--identification is achievable under unknown continuous latent distributions. We show that if the perturbations are applied only on mutually exclusive blocks of latents, we identify the latents up to those blocks. We also show that if these perturbation blocks overlap, we identify latents up to the smallest blocks shared across perturbations. Consequently, if there are blocks that intersect in one latent variable only, then such latents are identified up to permutation and scaling. We propose a natural estimation procedure based on this theory and illustrate it on low-dimensional synthetic and image-based experiments.
ML · Sep 24, 2022
Interventional Causal Representation Learning
Kartik Ahuja, Divyat Mahajan, Yixin Wang et al.
Causal representation learning seeks to extract high-level latent factors from low-level sensory data. Most existing methods rely on observational data and structural assumptions (e.g., conditional independence) to identify the latent factors. However, interventional data is prevalent across applications. Can interventional data facilitate causal representation learning? We explore this question in this paper. The key observation is that interventional data often carries geometric signatures of the latent factors' support (i.e. what values each latent can possibly take). For example, when the latent factors are causally connected, interventions can break the dependency between the intervened latents' support and their ancestors'. Leveraging this fact, we prove that the latent causal factors can be identified up to permutation and scaling given data from perfect $do$ interventions. Moreover, we can achieve block affine identification, namely the estimated latent factors are only entangled with a few other latents if we have access to data from imperfect interventions. These results highlight the unique power of interventional data in causal representation learning; they can enable provable identification of latent factors without any assumptions about their distributions or dependency structure.
90.7 · CV · Jun 3
Who Needs Labels? Adapting Vision Foundation Models With the Metadata You Already Have
Elouan Gardès, Seung Eun Yi, Kartik Ahuja et al.
We propose a label-free approach to adapt powerful but generic vision foundation models to specialized scientific domains. Standard supervised fine-tuning is often ill-suited to these settings: labels are scarce, and task-specific training can collapse the model's generality and hurt robustness. We instead leverage metadata to adapt representations to new domains in a self-supervised manner. Our method, FINO, combines a standard self-supervised objective with flexible metadata guidance that handles both highly granular discrete metadata and continuous metadata. It encourages the representation to preserve informative factors while suppressing spurious ones. Across subcellular fluorescence microscopy, Earth observation, wildlife monitoring, and medical imaging, FINO consistently outperforms standard unsupervised domain adaptation and fully supervised adaptation. It also exceeds highly-specialized domain-specific state of the art, while using no task labels for backbone adaptation and only lightweight probes for supervision.
LG · Jun 28, 2023
On the Identifiability of Quantized Factors
Vitória Barin-Pacela, Kartik Ahuja, Simon Lacoste-Julien et al. · mila
Disentanglement aims to recover meaningful latent ground-truth factors solely from the observed distribution, and is formalized through the theory of identifiability. The identifiability of independent latent factors is proven to be impossible in the unsupervised i.i.d. setting under a general nonlinear map from factors to observations. In this work, however, we demonstrate that it is possible to recover quantized latent factors under a generic nonlinear diffeomorphism. We only assume that the latent factors have independent discontinuities in their density, without requiring the factors to be statistically independent. We introduce this novel form of identifiability, termed quantized factor identifiability, and provide a comprehensive proof of the recovery of the quantized factors.
LG · Oct 31, 2022
FL Games: A Federated Learning Framework for Distribution Shifts
Sharut Gupta, Kartik Ahuja, Mohammad Havaei et al.
Federated learning aims to train predictive models for data that is distributed across clients, under the orchestration of a server. However, participating clients typically each hold data from a different distribution, which can lead to catastrophic generalization failures on data from a different client that represents a new domain. In this work, we argue that in order to generalize better across non-i.i.d. clients, it is imperative to only learn correlations that are stable and invariant across domains. We propose FL GAMES, a game-theoretic framework for federated learning that learns causal features that are invariant across clients. While training to achieve the Nash equilibrium, the traditional best response strategy suffers from high-frequency oscillations. We demonstrate that FL GAMES effectively resolves this challenge and exhibits smooth performance curves. Further, FL GAMES scales well in the number of clients, requires significantly fewer communication rounds, and is agnostic to device heterogeneity. Through empirical evaluation, we demonstrate that FL GAMES achieves high out-of-distribution performance on various benchmarks.
LG · May 23, 2022
FL Games: A federated learning framework for distribution shifts
Sharut Gupta, Kartik Ahuja, Mohammad Havaei et al.
Federated learning aims to train predictive models for data that is distributed across clients, under the orchestration of a server. However, participating clients typically each hold data from a different distribution, whereby predictive models with strong in-distribution generalization can fail catastrophically on unseen domains. In this work, we argue that in order to generalize better across non-i.i.d. clients, it is imperative to only learn correlations that are stable and invariant across domains. We propose FL Games, a game-theoretic framework for federated learning that learns causal features that are invariant across clients. While training to achieve the Nash equilibrium, the traditional best response strategy suffers from high-frequency oscillations. We demonstrate that FL Games effectively resolves this challenge and exhibits smooth performance curves. Further, FL Games scales well in the number of clients, requires significantly fewer communication rounds, and is agnostic to device heterogeneity. Through empirical evaluation, we demonstrate that FL Games achieves high out-of-distribution performance on various benchmarks.
ML · May 23, 2022
Why does Throwing Away Data Improve Worst-Group Error?
Kamalika Chaudhuri, Kartik Ahuja, Martin Arjovsky et al.
When facing data with imbalanced classes or groups, practitioners follow an intriguing strategy to achieve best results. They throw away examples until the classes or groups are balanced in size, and then perform empirical risk minimization on the reduced training set. This opposes common wisdom in learning theory, where the expected error is supposed to decrease as the dataset grows in size. In this work, we leverage extreme value theory to address this apparent contradiction. Our results show that the tails of the data distribution play an important role in determining the worst-group-accuracy of linear classifiers. When learning on data with heavy tails, throwing away data restores the geometric symmetry of the resulting classifier, and therefore improves its worst-group generalization.
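The subsampling strategy the abstract above describes (discard examples until groups are equal in size, then train on the reduced set) can be sketched as below. This is an illustrative sketch only: the helper name `subsample_to_balance`, the toy data, and the choice of uniform random sampling within each group are assumptions, not details taken from the paper.

```python
import random
from collections import defaultdict


def subsample_to_balance(examples, group_of, seed=0):
    """Randomly drop examples so every group matches the smallest group's size."""
    if not examples:
        return []
    rng = random.Random(seed)
    by_group = defaultdict(list)
    for example in examples:
        by_group[group_of(example)].append(example)
    smallest = min(len(members) for members in by_group.values())
    balanced = []
    for members in by_group.values():
        # Sample without replacement; the smallest group is kept whole.
        balanced.extend(rng.sample(members, smallest))
    return balanced


# Toy data: 8 majority-group examples and 2 minority-group examples.
data = [(f"x{i}", "majority") for i in range(8)]
data += [(f"y{i}", "minority") for i in range(2)]
balanced = subsample_to_balance(data, group_of=lambda example: example[1])
```

Empirical risk minimization would then be run on `balanced` instead of `data`.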
LG · Apr 10, 2022
Towards efficient representation identification in supervised learning
Kartik Ahuja, Divyat Mahajan, Vasilis Syrgkanis et al.
Humans have a remarkable ability to disentangle complex sensory inputs (e.g., image, text) into simple factors of variation (e.g., shape, color) without much supervision. This ability has inspired many works that attempt to answer the following question: how do we invert the data generation process to extract those factors with minimal or no supervision? Several works in the literature on non-linear independent component analysis have established a negative result: without some knowledge of the data generation process or appropriate inductive biases, it is impossible to perform this inversion. In recent years, a lot of progress has been made on disentanglement under structural assumptions, e.g., when we have access to auxiliary information that makes the factors of variation conditionally independent. However, existing work requires a lot of auxiliary information, e.g., in supervised classification, it prescribes that the number of label classes should be at least equal to the total dimension of all factors of variation. In this work, we depart from these assumptions and ask: a) How can we get disentanglement when the auxiliary information does not provide conditional independence over the factors of variation? b) Can we reduce the amount of auxiliary information required for disentanglement? For a class of models where auxiliary information does not ensure conditional independence, we show theoretically and experimentally that disentanglement (to a large extent) is possible even when the auxiliary information dimension is much less than the dimension of the true latent representation.
AI · Feb 21, 2023
Reusable Slotwise Mechanisms
Trang Nguyen, Amin Mansouri, Kanika Madan et al.
Agents with the ability to comprehend and reason about the dynamics of objects would be expected to exhibit improved robustness and generalization in novel scenarios. However, achieving this capability necessitates not only an effective scene representation but also an understanding of the mechanisms governing interactions among object subsets. Recent studies have made significant progress in representing scenes using object slots. In this work, we introduce Reusable Slotwise Mechanisms, or RSM, a framework that models object dynamics by leveraging communication among slots along with a modular architecture capable of dynamically selecting reusable mechanisms for predicting the future states of each object slot. Crucially, RSM leverages the Central Contextual Information (CCI), enabling selected mechanisms to access the remaining slots through a bottleneck, effectively allowing for modeling of higher order and complex interactions that might require a sparse subset of objects. Experimental results demonstrate the superior performance of RSM compared to state-of-the-art methods across various future prediction and related downstream tasks, including Visual Question Answering and action planning. Furthermore, we showcase RSM's Out-of-Distribution generalization ability to handle scenes in intricate scenarios.
LG · Nov 15, 2022
Empirical Study on Optimizer Selection for Out-of-Distribution Generalization
Hiroki Naganuma, Kartik Ahuja, Shiro Takagi et al.
Modern deep learning systems do not generalize well when the test data distribution is slightly different from the training data distribution. While much promising work has been accomplished to address this fragility, a systematic study of the role of optimizers and their out-of-distribution generalization performance has not been undertaken. In this study, we examine the performance of popular first-order optimizers for different classes of distributional shift under empirical risk minimization and invariant risk minimization. We address this question for image and text classification using DomainBed, WILDS, and Backgrounds Challenge as testbeds for studying different types of shifts -- namely correlation and diversity shift. We search over a wide range of hyperparameters and examine classification accuracy (in-distribution and out-of-distribution) for over 20,000 models. We arrive at the following findings, which we expect to be helpful for practitioners: i) adaptive optimizers (e.g., Adam) perform worse than non-adaptive optimizers (e.g., SGD, momentum SGD) on out-of-distribution performance. In particular, even though there is no significant difference in in-distribution performance, we show a measurable difference in out-of-distribution performance. ii) in-distribution performance and out-of-distribution performance exhibit three types of behavior depending on the dataset -- linear returns, increasing returns, and diminishing returns. For example, in the training of natural language data using Adam, improving in-distribution performance does not significantly contribute to the out-of-distribution generalization performance.
LG · Oct 4, 2023
Multi-Domain Causal Representation Learning via Weak Distributional Invariances
Kartik Ahuja, Amin Mansouri, Yixin Wang
Causal representation learning has emerged as the center of action in causal machine learning research. In particular, multi-domain datasets present a natural opportunity for showcasing the advantages of causal representation learning over standard unsupervised representation learning. While recent works have taken crucial steps towards learning causal representations, they often lack applicability to multi-domain datasets due to over-simplifying assumptions about the data; e.g. each domain comes from a different single-node perfect intervention. In this work, we relax these assumptions and capitalize on the following observation: there often exists a subset of latents whose certain distributional properties (e.g., support, variance) remain stable across domains; this property holds when, for example, each domain comes from a multi-node imperfect intervention. Leveraging this observation, we show that autoencoders that incorporate such invariances can provably identify the stable set of latents from the rest across different settings.
LG · Feb 2
ReasonCACHE: Teaching LLMs To Reason Without Weight Updates
Sharut Gupta, Phillip Isola, Stefanie Jegelka et al.
Can large language models (LLMs) learn to reason without any weight update and only through in-context learning (ICL)? ICL is strikingly sample-efficient, often learning from only a handful of demonstrations, but complex reasoning tasks typically demand many training examples to learn from. However, naively scaling ICL by adding more demonstrations breaks down at this scale: attention costs grow quadratically, performance saturates or degrades with longer contexts, and the approach remains a shallow form of learning. Due to these limitations, practitioners predominantly rely on in-weight learning (IWL) to induce reasoning. In this work, we show that by using Prefix Tuning, LLMs can learn to reason without overloading the context window and without any weight updates. We introduce ReasonCACHE, an instantiation of this mechanism that distills demonstrations into a fixed key-value cache. Empirically, across challenging reasoning benchmarks, including GPQA-Diamond, ReasonCACHE outperforms standard ICL and matches or surpasses IWL approaches. Further, it achieves this all while being more efficient across three key axes: data, inference cost, and trainable parameters. We also theoretically prove that ReasonCACHE can be strictly more expressive than low-rank weight updates, since the latter tie expressivity to input rank, whereas ReasonCACHE bypasses this constraint by directly injecting key-values into the attention mechanism. Together, our findings identify ReasonCACHE as a middle path between in-context and in-weight learning, providing a scalable algorithm for learning reasoning skills beyond the context window without modifying parameters. Our project page: https://reasoncache.github.io/
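The idea of injecting key-values directly into attention, as mentioned in the abstract above, can be illustrated with a generic single-head sketch. This is not ReasonCACHE's implementation: it shows only the general prefix-tuning pattern of prepending extra key-value pairs to those computed from the context, and the function names and toy numbers are hypothetical.

```python
import math


def attention(query, keys, values):
    """Scaled dot-product attention for one query over lists of key/value vectors."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


def attention_with_prefix(query, keys, values, prefix_keys, prefix_values):
    """Attend over prefix key-values followed by the context's own key-values."""
    return attention(query, prefix_keys + keys, prefix_values + values)


# Toy example: one prefix slot (which would be learned) and one context token.
output = attention_with_prefix(
    query=[0.0, 0.0],
    keys=[[1.0, 0.0]],
    values=[[2.0, 0.0]],
    prefix_keys=[[0.0, 1.0]],
    prefix_values=[[0.0, 4.0]],
)
```

In this pattern the model weights stay frozen; only the prefix key-value entries would be trained, and their number is fixed regardless of how many demonstrations were used to fit them.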
LG · Sep 18, 2023
Context is Environment
Sharut Gupta, Stefanie Jegelka, David Lopez-Paz et al.
Two lines of work are taking the central stage in AI research. On the one hand, the community is making increasing efforts to build models that discard spurious correlations and generalize better in novel test environments. Unfortunately, the bitter lesson so far is that no proposal convincingly outperforms a simple empirical risk minimization baseline. On the other hand, large language models (LLMs) have erupted as algorithms able to learn in-context, generalizing on-the-fly to eclectic contextual circumstances that users enforce by means of prompting. In this paper, we argue that context is environment, and posit that in-context learning holds the key to better domain generalization. Via extensive theory and experiments, we show that paying attention to context -- unlabeled examples as they arrive -- allows our proposed In-Context Risk Minimization (ICRM) algorithm to zoom-in on the test environment risk minimizer, leading to significant out-of-distribution performance improvements. From all of this, two messages are worth taking home. Researchers in domain generalization should consider environment as context, and harness the adaptive power of in-context learning. Researchers in LLMs should consider context as environment, to better structure data towards generalization.
LG · Jan 26
Teaching Models to Teach Themselves: Reasoning at the Edge of Learnability
Shobhita Sundaram, John Quan, Ariel Kwiatkowski et al.
Can a model learn to escape its own learning plateau? Reinforcement learning methods for finetuning large reasoning models stall on datasets with low initial success rates, and thus little training signal. We investigate a fundamental question: Can a pretrained LLM leverage latent knowledge to generate an automated curriculum for problems it cannot solve? To explore this, we design SOAR: A self-improvement framework designed to surface these pedagogical signals through meta-RL. A teacher copy of the model proposes synthetic problems for a student copy, and is rewarded with its improvement on a small subset of hard problems. Critically, SOAR grounds the curriculum in measured student progress rather than intrinsic proxy rewards. Our study on the hardest subsets of mathematical benchmarks (0/128 success) reveals three core findings. First, we show that it is possible to realize bi-level meta-RL that unlocks learning under sparse, binary rewards by sharpening a latent capacity of pretrained models to generate useful stepping stones. Second, grounded rewards outperform intrinsic reward schemes used in prior LLM self-play, reliably avoiding the instability and diversity collapse modes they typically exhibit. Third, analyzing the generated questions reveals that structural quality and well-posedness are more critical for learning progress than solution correctness. Our results suggest that the ability to generate useful stepping stones does not require the preexisting ability to actually solve the hard problems, paving a principled path to escape reasoning plateaus without additional curated data.
LG · Jun 8, 2020 · Code
Adversarial Feature Desensitization
Pouya Bashivan, Reza Bayat, Adam Ibrahim et al.
Neural networks are known to be vulnerable to adversarial attacks -- slight but carefully constructed perturbations of the inputs which can drastically impair the network's performance. Many defense methods have been proposed for improving robustness of deep networks by training them on adversarially perturbed inputs. However, these models often remain vulnerable to new types of attacks not seen during training, and even to slightly stronger versions of previously seen attacks. In this work, we propose a novel approach to adversarial robustness, which builds upon the insights from the domain adaptation field. Our method, called Adversarial Feature Desensitization (AFD), aims at learning features that are invariant towards adversarial perturbations of the inputs. This is achieved through a game where we learn features that are both predictive and robust (insensitive to adversarial attacks), i.e. cannot be used to discriminate between natural and adversarial data. Empirical results on several benchmarks demonstrate the effectiveness of the proposed approach against a wide range of attack types and attack strengths. Our code is available at https://github.com/BashivanLab/afd.
LG · Feb 7, 2024
On Provable Length and Compositional Generalization
Kartik Ahuja, Amin Mansouri
Out-of-distribution generalization capabilities of sequence-to-sequence models can be studied through the lens of two crucial forms of generalization: length generalization -- the ability to generalize to longer sequences than ones seen during training, and compositional generalization -- the ability to generalize to token combinations not seen during training. In this work, we provide the first provable guarantees on length and compositional generalization for common sequence-to-sequence models -- deep sets, transformers, state space models, and recurrent neural nets -- trained to minimize the prediction error. We show that limited capacity versions of these different architectures achieve both length and compositional generalization provided the training distribution is sufficiently diverse. In the first part, we study structured limited capacity variants of different architectures and arrive at the generalization guarantees with limited diversity requirements on the training distribution. In the second part, we study limited capacity variants with fewer structural assumptions and arrive at generalization guarantees but with more diversity requirements on the training distribution. Further, we also show that chain-of-thought supervision enables length generalization in higher capacity counterparts of the different architectures we study.
CL · Feb 11, 2025
Unveiling Simplicities of Attention: Adaptive Long-Context Head Identification
Konstantin Donhauser, Charles Arnal, Mohammad Pezeshki et al.
The ability to process long contexts is crucial for many natural language processing tasks, yet it remains a significant challenge. While substantial progress has been made in enhancing the efficiency of attention mechanisms, there is still a gap in understanding how attention heads function in long-context settings. In this paper, we observe that while certain heads consistently attend to local information only, others swing between attending to local and long-context information depending on the query. This raises the question: can we identify which heads require long-context information to predict the next token accurately? We demonstrate that it's possible to predict which heads are crucial for long-context processing using only local keys. The core idea here is to exploit a simple model for the long-context scores via second moment approximations. These findings unveil simple properties of attention in the context of long sequences, and open the door to potentially significant gains in efficiency.
LG · Apr 8, 2024
DRoP: Distributionally Robust Data Pruning
Artem Vysogorets, Kartik Ahuja, Julia Kempe
In the era of exceptionally data-hungry models, careful selection of the training data is essential to mitigate the extensive costs of deep learning. Data pruning offers a solution by removing redundant or uninformative samples from the dataset, which yields faster convergence and improved neural scaling laws. However, little is known about its impact on classification bias of the trained models. We conduct the first systematic study of this effect and reveal that existing data pruning algorithms can produce highly biased classifiers. We present a theoretical analysis of the classification risk in a mixture of Gaussians to argue that choosing appropriate class pruning ratios, coupled with random pruning within classes, has the potential to improve worst-class performance. We thus propose DRoP, a distributionally robust approach to pruning, and empirically demonstrate its performance on standard computer vision benchmarks. In sharp contrast to existing algorithms, our proposed method continues improving distributional robustness at a tolerable drop of average performance as we prune more from the datasets.
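The combination the abstract above argues for, class-specific pruning ratios with random pruning inside each class, can be sketched as follows. The abstract does not say how the ratios are chosen, so `keep_ratio` is simply passed in as a hypothetical input; the function name and toy labels are likewise illustrative and not DRoP's actual procedure.

```python
import random
from collections import defaultdict


def prune_by_class(labels, keep_ratio, seed=0):
    """Return sorted indices kept after random pruning within each class."""
    rng = random.Random(seed)
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)
    kept = []
    for label, indices in indices_by_class.items():
        # keep_ratio[label] is the fraction of this class to retain (assumed given).
        n_keep = round(keep_ratio[label] * len(indices))
        kept.extend(rng.sample(indices, n_keep))
    return sorted(kept)


# Toy dataset: 10 examples per class, with a higher keep ratio for "dog".
labels = ["cat"] * 10 + ["dog"] * 10
kept = prune_by_class(labels, keep_ratio={"cat": 0.5, "dog": 0.8})
```

Because selection within a class is uniform at random, the only design decision left is the per-class ratio, which is where a robustness-oriented method would intervene.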
LG · Sep 1, 2025
Distilled Pretraining: A modern lens of Data, In-Context Learning and Test-Time Scaling
Sachin Goyal, David Lopez-Paz, Kartik Ahuja
In the past year, distillation has seen a renewed prominence in large language model (LLM) pretraining, exemplified by the Llama-3.2 and Gemma model families. While distillation has historically been shown to improve statistical modeling, its effects on new paradigms that are key to modern LLMs, such as test-time scaling and in-context learning, remain underexplored. In this work, we make three main contributions. First, we show that pretraining with distillation yields models that exhibit remarkably better test-time scaling. Second, we observe that this benefit comes with a trade-off: distillation impairs in-context learning capabilities, particularly the one modeled via induction heads. Third, to demystify these findings, we study distilled pretraining in a sandbox of a bigram model, which helps us isolate the common principal factor behind our observations. Finally, using these insights, we shed light on various design choices for pretraining that should help practitioners going forward.
LG · Nov 25, 2025
Operationalizing Quantized Disentanglement
Vitoria Barin-Pacela, Kartik Ahuja, Simon Lacoste-Julien et al.
Recent theoretical work established the unsupervised identifiability of quantized factors under any diffeomorphism. The theory assumes that quantization thresholds correspond to axis-aligned discontinuities in the probability density of the latent factors. By constraining a learned map to have a density with axis-aligned discontinuities, we can recover the quantization of the factors. However, translating this high-level principle into an effective practical criterion remains challenging, especially under nonlinear maps. Here, we develop a criterion for unsupervised disentanglement by encouraging axis-aligned discontinuities. Discontinuities manifest as sharp changes in the estimated density of factors and form what we call cliffs. Following the definition of independent discontinuities from the theory, we encourage the location of the cliffs along a factor to be independent of the values of the other factors. We show that our method, Cliff, outperforms the baselines on all disentanglement benchmarks, demonstrating its effectiveness in unsupervised disentanglement.
LG · Oct 16, 2025
Beyond Multi-Token Prediction: Pretraining LLMs with Future Summaries
Divyat Mahajan, Sachin Goyal, Badr Youbi Idrissi et al. · meta-ai
Next-token prediction (NTP) has driven the success of large language models (LLMs), but it struggles with long-horizon reasoning, planning, and creative writing, with these limitations largely attributed to teacher-forced training. Multi-token prediction (MTP) partially mitigates these issues by predicting several future tokens at once, but it mostly captures short-range dependencies and offers limited improvement. We propose future summary prediction (FSP), which trains an auxiliary head to predict a compact representation of the long-term future, preserving information relevant for long-form generations. We explore two variants of FSP: handcrafted summaries, for example, a bag of words summary of the future of the sequence, and learned summaries, which use embeddings produced by a reverse language model trained from right to left. Large-scale pretraining experiments (3B and 8B-parameter models) demonstrate that FSP provides improvements over both NTP and MTP across math, reasoning, and coding benchmarks.
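The handcrafted variant mentioned in the abstract above, a bag-of-words summary of the future of the sequence, might be built as in the sketch below. The multi-hot encoding, the fixed `horizon` window, and the function name are assumptions made for illustration; the abstract does not specify how the summary target is constructed.

```python
def future_bag_of_words(token_ids, vocab_size, horizon):
    """For each position, build a multi-hot vector of tokens in the next `horizon` steps."""
    targets = []
    for position in range(len(token_ids)):
        future = token_ids[position + 1 : position + 1 + horizon]
        multi_hot = [0] * vocab_size
        for token in future:
            multi_hot[token] = 1  # presence only; repeated tokens are not counted
        targets.append(multi_hot)
    return targets


# Toy sequence over a 4-token vocabulary, summarizing up to 3 future tokens.
targets = future_bag_of_words([2, 0, 3, 3, 1], vocab_size=4, horizon=3)
```

An auxiliary head would then be trained to predict `targets[t]` from the hidden state at position `t`, alongside the usual next-token objective.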
LG · May 26, 2023
A Closer Look at In-Context Learning under Distribution Shifts
Kartik Ahuja, David Lopez-Paz
In-context learning, a capability that enables a model to learn from input examples on the fly without necessitating weight updates, is a defining characteristic of large language models. In this work, we follow the setting proposed in (Garg et al., 2022) to better understand the generality and limitations of in-context learning from the lens of the simple yet fundamental task of linear regression. The key question we aim to address is: Are transformers more adept than some natural and simpler architectures at performing in-context learning under varying distribution shifts? To compare transformers, we propose to use a simple architecture based on set-based Multi-Layer Perceptrons (MLPs). We find that both transformers and set-based MLPs exhibit in-context learning under in-distribution evaluations, but transformers more closely emulate the performance of ordinary least squares (OLS). Transformers also display better resilience to mild distribution shifts, where set-based MLPs falter. However, under severe distribution shifts, both models' in-context learning abilities diminish.
LGJan 28, 2022
Locally Invariant Explanations: Towards Stable and Unidirectional Explanations through Local Invariant Learning
Amit Dhurandhar, Karthikeyan Ramamurthy, Kartik Ahuja et al.
The locally interpretable model-agnostic explanations (LIME) method is one of the most popular methods used to explain black-box models at a per-example level. Although many variants have been proposed, few provide a simple way to produce high-fidelity explanations that are also stable and intuitive. In this work, we provide a novel perspective by proposing a model-agnostic local explanation method inspired by the invariant risk minimization (IRM) principle -- originally proposed for (global) out-of-distribution generalization -- to provide such high-fidelity explanations that are also stable and unidirectional across nearby examples. Our method is based on a game-theoretic formulation where we theoretically show that our approach has a strong tendency to eliminate features where the gradient of the black-box function abruptly changes sign in the locality of the example we want to explain, while in other cases it is more careful and will choose a more conservative (feature) attribution, a behavior which can be highly desirable for recourse. Empirically, we show on tabular, image and text data that the quality of our explanations with neighborhoods formed using random perturbations is much better than that of LIME and in some cases even comparable to other methods that use realistic neighbors sampled from the data manifold. This is desirable given that learning a manifold to either create realistic neighbors or to project explanations is typically expensive or may even be impossible. Moreover, our algorithm is simple and efficient to train, and can ascertain stable input features for local decisions of a black-box without access to side information such as a (partial) causal graph, as has been seen in some recent works.
LGOct 29, 2021
Properties from Mechanisms: An Equivariance Perspective on Identifiable Representation Learning
Kartik Ahuja, Jason Hartford, Yoshua Bengio
A key goal of unsupervised representation learning is "inverting" a data generating process to recover its latent properties. Existing work that provably achieves this goal relies on strong assumptions on relationships between the latent variables (e.g., independence conditional on auxiliary information). In this paper, we take a very different perspective on the problem and ask, "Can we instead identify latent properties by leveraging knowledge of the mechanisms that govern their evolution?" We provide a complete characterization of the sources of non-identifiability as we vary knowledge about a set of possible mechanisms. In particular, we prove that if we know the exact mechanisms under which the latent properties evolve, then identification can be achieved up to any equivariances that are shared by the underlying mechanisms. We generalize this characterization to settings where we only know some hypothesis class over possible mechanisms, as well as settings where the mechanisms are stochastic. We demonstrate the power of this mechanism-based perspective by showing that we can leverage our results to generalize existing identifiable representation learning results. These results suggest that by exploiting inductive biases on mechanisms, it is possible to design a range of new identifiable representation learning approaches.
LGJun 22, 2021
Finding Valid Adjustments under Non-ignorability with Minimal DAG Knowledge
Abhin Shah, Karthikeyan Shanmugam, Kartik Ahuja
Treatment effect estimation from observational data is a fundamental problem in causal inference. There are two very different schools of thought that have tackled this problem. On the one hand, the Pearlian framework commonly assumes structural knowledge (provided by an expert) in the form of directed acyclic graphs and provides graphical criteria such as the back-door criterion to identify valid adjustment sets. On the other hand, the potential outcomes (PO) framework commonly assumes that all observed features satisfy ignorability (i.e., no hidden confounding), which in general is untestable. In prior works that attempted to bridge these frameworks, there is an observational criterion to identify an anchor variable, and if a subset of covariates (not involving the anchor variable) passes a suitable conditional independence criterion, then that subset is a valid back-door. Our main result strengthens these prior results by showing that under a different expert-driven structural knowledge -- that one variable is a direct causal parent of the treatment variable -- remarkably, testing for subsets (not involving the known parent variable) that are valid back-doors is equivalent to an invariance test. Importantly, we also cover the non-trivial case where the entire set of observed features is not ignorable (generalizing the PO framework) without requiring knowledge of all parents of the treatment variable. Our key technical idea involves generation of a synthetic sub-sampling (or environment) variable that is a function of the known parent variable. In addition to designing an invariance test, this sub-sampling variable allows us to leverage Invariant Risk Minimization, and thus connects finding valid adjustments (in the non-ignorable observational setting) to representation learning. We demonstrate the effectiveness and tradeoffs of our approaches on a variety of synthetic data as well as real causal effect estimation benchmarks.
LGJun 11, 2021
Invariance Principle Meets Information Bottleneck for Out-of-Distribution Generalization
Kartik Ahuja, Ethan Caballero, Dinghuai Zhang et al.
The invariance principle from causality is at the heart of notable approaches such as invariant risk minimization (IRM) that seek to address out-of-distribution (OOD) generalization failures. Despite the promising theory, invariance principle-based approaches fail in common classification tasks, where invariant (causal) features capture all the information about the label. Are these failures due to the methods failing to capture the invariance? Or is the invariance principle itself insufficient? To answer these questions, we revisit the fundamental assumptions in linear regression tasks, where invariance-based approaches were shown to provably generalize OOD. In contrast to the linear regression tasks, we show that for linear classification tasks we need much stronger restrictions on the distribution shifts, or otherwise OOD generalization is impossible. Furthermore, even with appropriate restrictions on distribution shifts in place, we show that the invariance principle alone is insufficient. We prove that a form of the information bottleneck constraint along with invariance helps address key failures when invariant features capture all the information about the label and also retains the existing success when they do not. We propose an approach that incorporates both of these principles and demonstrate its effectiveness in several experiments.
LGJun 5, 2021
Can Subnetwork Structure be the Key to Out-of-Distribution Generalization?
Dinghuai Zhang, Kartik Ahuja, Yilun Xu et al.
Can models with a particular structure avoid being biased towards spurious correlation in out-of-distribution (OOD) generalization? Peters et al. (2016) provide a positive answer for linear cases. In this paper, we use a functional modular probing method to analyze deep model structures under the OOD setting. We demonstrate that even in biased models (which focus on spurious correlation) there still exist unbiased functional subnetworks. Furthermore, we articulate and demonstrate the functional lottery ticket hypothesis: a full network contains a subnetwork that can achieve better OOD performance. We then propose Modular Risk Minimization to solve the subnetwork selection problem. Our algorithm learns the subnetwork structure from a given dataset, and can be combined with any other OOD regularization method. Experiments on various OOD generalization tasks corroborate the effectiveness of our method.
LGJun 4, 2021
SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization
Soroosh Shahtalebi, Jean-Christophe Gagnon-Audet, Touraj Laleh et al.
A major bottleneck in the real-world applications of machine learning models is their failure to generalize to unseen domains whose data are not i.i.d. with respect to the training domains. This failure often stems from learning non-generalizable features in the training domains that are spuriously correlated with the label of the data. To address this shortcoming, there has been a growing surge of interest in learning good explanations that are hard to vary, which is studied under the notion of Out-of-Distribution (OOD) Generalization. The search for good explanations that are invariant across different domains can be seen as finding local (global) minima in the loss landscape that hold true across all of the training domains. In this paper, we propose a masking strategy, which determines a continuous weight based on the agreement of gradients that flow in each edge of the network, in order to control the amount of update received by the edge in each step of optimization. In particular, our proposed technique, referred to as "Smoothed-AND (SAND)-masking", not only validates the agreement in the direction of gradients but also promotes the agreement among their magnitudes to further ensure the discovery of invariances across training domains. SAND-mask is validated over the DomainBed benchmark for domain generalization and significantly improves the state-of-the-art accuracy on the Colored MNIST dataset while providing competitive results on other domain generalization datasets.
LGMar 13, 2021
Treatment Effect Estimation using Invariant Risk Minimization
Abhin Shah, Kartik Ahuja, Karthikeyan Shanmugam et al.
Inferring the causal individual treatment effect (ITE) from observational data is a challenging problem whose difficulty is exacerbated by the presence of treatment assignment bias. In this work, we propose a new way to estimate the ITE using the domain generalization framework of invariant risk minimization (IRM). IRM uses data from multiple domains, learns predictors that do not exploit spurious domain-dependent factors, and generalizes better to unseen domains. We propose an IRM-based ITE estimator aimed at tackling treatment assignment bias when there is little support overlap between the control group and the treatment group. We accomplish this by creating diversity: given a single dataset, we split the data into multiple domains artificially. These diverse domains are then exploited by IRM to more effectively generalize regression-based models to data regions that lack support overlap. We show gains over classical regression approaches to ITE estimation in settings where support mismatch is more pronounced.
LGDec 22, 2020
Learning to Initialize Gradient Descent Using Gradient Descent
Kartik Ahuja, Amit Dhurandhar, Kush R. Varshney
Non-convex optimization problems are challenging to solve; the success and computational expense of a gradient descent algorithm or variant depend heavily on the initialization strategy. Often, either random initialization is used or initialization rules are carefully designed by exploiting the nature of the problem class. As a simple alternative to hand-crafted initialization rules, we propose an approach for learning "good" initialization rules from previous solutions. We provide theoretical guarantees that establish conditions that are sufficient in all cases and also necessary in some under which our approach performs better than random initialization. We apply our methodology to various non-convex problems such as generating adversarial examples, generating post hoc explanations for black-box machine learning models, and allocating communication spectrum, and show consistent gains over other initialization techniques.
LGOct 30, 2020
Empirical or Invariant Risk Minimization? A Sample Complexity Perspective
Kartik Ahuja, Jun Wang, Amit Dhurandhar et al.
Recently, invariant risk minimization (IRM) was proposed as a promising solution to address out-of-distribution (OOD) generalization. However, it is unclear when IRM should be preferred over the widely-employed empirical risk minimization (ERM) framework. In this work, we analyze both these frameworks from the perspective of sample complexity, thus taking a firm step towards answering this important question. We find that depending on the type of data generation mechanism, the two approaches might have very different finite sample and asymptotic behavior. For example, in the covariate shift setting we see that the two approaches not only arrive at the same asymptotic solution, but also have similar finite sample behavior with no clear winner. For other distribution shifts such as those involving confounders or anti-causal variables, however, the two approaches arrive at different asymptotic solutions where IRM is guaranteed to be close to the desired OOD solutions in the finite sample regime, while ERM is biased even asymptotically. We further investigate how different factors -- the number of environments, complexity of the model, and IRM penalty weight -- impact the sample complexity of IRM in relation to its distance from the OOD solutions.
LGOct 28, 2020
Linear Regression Games: Convergence Guarantees to Approximate Out-of-Distribution Solutions
Kartik Ahuja, Karthikeyan Shanmugam, Amit Dhurandhar
Recently, invariant risk minimization (IRM) (Arjovsky et al.) was proposed as a promising solution to address out-of-distribution (OOD) generalization. In Ahuja et al., it was shown that solving for the Nash equilibria of a new class of "ensemble-games" is equivalent to solving IRM. In this work, we extend the framework in Ahuja et al. for linear regressions by projecting the ensemble-game on an $\ell_{\infty}$ ball. We show that such projections help achieve non-trivial OOD guarantees despite not achieving perfect invariance. For linear models with confounders, we prove that Nash equilibria of these games are closer to the ideal OOD solutions than the standard empirical risk minimization (ERM), and we also provide learning algorithms that provably converge to these Nash equilibria. Empirical comparisons of the proposed approach with the state-of-the-art show consistent gains in achieving OOD solutions in several settings involving anti-causal variables and confounders.
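For readers unfamiliar with the projection step mentioned above, the Euclidean projection onto an $\ell_{\infty}$ ball reduces to clipping each coordinate. The sketch below shows only that generic operation; the function name `project_linf`, the radius, and the example weights are assumptions, and how the projection enters the ensemble-game updates is not described in the abstract.

```python
from typing import List


def project_linf(w: List[float], radius: float) -> List[float]:
    """Euclidean projection of w onto {v : max_i |v_i| <= radius}.

    Because the constraint is separable across coordinates, the projection
    simply clips each coordinate to the interval [-radius, radius].
    """
    return [max(-radius, min(radius, wi)) for wi in w]


# Hypothetical weight vector after an unconstrained update.
w = [1.5, -0.2, -3.0]
w_proj = project_linf(w, radius=1.0)
```

Here `w_proj` keeps the second coordinate unchanged and clips the other two to the boundary of the ball.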
LGFeb 11, 2020
Invariant Risk Minimization Games
Kartik Ahuja, Karthikeyan Shanmugam, Kush R. Varshney et al.
The standard risk minimization paradigm of machine learning is brittle when operating in environments whose test distributions are different from the training distribution due to spurious correlations. Training on data from many environments and finding invariant predictors reduces the effect of spurious features by concentrating models on features that have a causal relationship with the outcome. In this work, we pose such invariant risk minimization as finding the Nash equilibrium of an ensemble game among several environments. By doing so, we develop a simple training algorithm that uses best response dynamics and, in our experiments, yields similar or better empirical accuracy with much lower variance than the challenging bi-level optimization problem of Arjovsky et al. (2019). One key theoretical contribution is showing that the set of Nash equilibria for the proposed game is equivalent to the set of invariant predictors for any finite number of environments, even with nonlinear classifiers and transformations. As a result, our method also retains the generalization guarantees to a large set of environments shown in Arjovsky et al. (2019). The proposed algorithm adds to the collection of successful game-theoretic machine learning algorithms such as generative adversarial networks.
LGMay 2, 2019
Estimating Kullback-Leibler Divergence Using Kernel Machines
Kartik Ahuja
Recently, a method called the Mutual Information Neural Estimator (MINE) that uses neural networks has been proposed to estimate mutual information and, more generally, the Kullback-Leibler (KL) divergence between two distributions. The method uses the Donsker-Varadhan representation to arrive at the estimate of the KL divergence and is better than the existing estimators in terms of scalability and flexibility. The output of the MINE algorithm is not guaranteed to be a consistent estimator. We propose a new estimator that, instead of searching among functions characterized by neural networks, searches the functions in a Reproducing Kernel Hilbert Space. We prove that the proposed estimator is consistent. We carry out simulations and show that when the datasets are small the proposed estimator is more reliable than the MINE estimator, and when the datasets are large the performance of the two methods is close.
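The Donsker-Varadhan representation referred to above states that $\mathrm{KL}(P \,\|\, Q) = \sup_{T} \; \mathbb{E}_{P}[T] - \log \mathbb{E}_{Q}[e^{T}]$, where the supremum ranges over critic functions $T$. The sketch below checks this numerically on a small discrete example, where the critic is just a table of values; the distributions `p` and `q` and all function names are assumptions chosen for illustration, and neither the neural-network search of MINE nor the kernel-based search of the proposed estimator is reproduced here.

```python
import math
from typing import List


def kl_divergence(p: List[float], q: List[float]) -> float:
    """Exact KL(P || Q) for discrete distributions with q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def dv_objective(t: List[float], p: List[float], q: List[float]) -> float:
    """Donsker-Varadhan objective E_P[T] - log E_Q[exp(T)] for a tabular critic t."""
    expect_p = sum(pi * ti for pi, ti in zip(p, t))
    expect_q_exp = sum(qi * math.exp(ti) for qi, ti in zip(q, t))
    return expect_p - math.log(expect_q_exp)


# Two assumed distributions over three outcomes.
p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]

kl = kl_divergence(p, q)
# The supremum is attained at T*(x) = log(p(x) / q(x)), up to an additive constant.
t_star = [math.log(pi / qi) for pi, qi in zip(p, q)]
dv_at_optimum = dv_objective(t_star, p, q)
# Any other critic, such as the all-zero one, gives a lower bound on the KL divergence.
dv_at_zero = dv_objective([0.0, 0.0, 0.0], p, q)
```

In practice the expectations are replaced by sample averages and the critic is restricted to a function class, which is where the choice between neural networks and a Reproducing Kernel Hilbert Space matters.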
LGNov 2, 2018
Risk-Stratify: Confident Stratification Of Patients Based On Risk
Kartik Ahuja, Mihaela van der Schaar
A clinician desires to use a risk-stratification method that achieves confident risk-stratification, meaning that the risk estimates of the different patients reflect the true risks with high probability. This allows the clinician to use these risks to make accurate predictions about prognosis and decisions about screening and treatments for the current patient. We develop Risk-stratify, a two-phase algorithm that is designed to achieve confident risk-stratification. In the first phase, we grow a tree to partition the covariate space. Each node in the tree is split using statistical tests that determine whether the risks of the child nodes are different. The choice of the statistical test depends on whether the data is censored (Log-rank test) or not (U-test). The set of the leaves of the tree forms a partition. The risk distribution of patients that belong to a leaf is different from that of the sibling leaf but not necessarily from the rest of the leaves. Therefore, some of the leaves that have similar underlying risks are incorrectly specified to have different risks. In the second phase, we develop a novel recursive graph decomposition approach to address this problem. We merge the leaves of the tree that have similar risks to form new leaves that form the final output. We apply Risk-stratify to a cohort of patients (with no history of cardiovascular disease) from UK Biobank and assess their risk for cardiovascular disease. Risk-stratify significantly improves risk-stratification, i.e., a lower fraction of the groups have over- or underestimated risks (measured in terms of false discovery rate; 33% reduction) in comparison to state-of-the-art methods for cardiovascular prediction (Random forests, Cox model, etc.). We find that the Cox model significantly overestimates the risk of 21,621 patients out of 216,211 patients. Risk-stratify can accurately categorize 2,987 of these 21,621 patients as low-risk individuals.
LGJun 27, 2018
Optimal Piecewise Local-Linear Approximations
Kartik Ahuja, William Zame, Mihaela van der Schaar
Existing works on "black-box" model interpretation use local-linear approximations to explain the predictions made for each data instance in terms of the importance assigned to the different features for arriving at the prediction. These works provide instancewise explanations and thus give a local view of the model. To be able to trust the model, it is important to understand the global model behavior, and relatively few works address this. Piecewise local-linear models provide a natural way to extend local-linear models to explain the global behavior of the model. In this work, we provide a dynamic programming based framework to obtain piecewise approximations of the black-box model. We also provide provable guarantees on fidelity, i.e., how well the explanations reflect the black-box model. We carry out simulations on synthetic and real datasets to show the utility of the proposed approach. Finally, we show that the ideas developed for our framework can also be used to address the problem of clustering for one-dimensional data. We give a polynomial time algorithm and prove that it achieves optimal clustering.
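To make the one-dimensional clustering remark concrete, the following is a standard dynamic program for optimal 1-D clustering, given here as a sketch rather than as the paper's algorithm. It assumes a fixed number of clusters `k` and a within-cluster sum-of-squared-deviations objective; neither assumption comes from the abstract, and the function name `optimal_1d_clustering` is hypothetical. The key fact it relies on is that, in one dimension, optimal clusters under this objective are contiguous in sorted order.

```python
from typing import List, Tuple


def optimal_1d_clustering(xs: List[float], k: int) -> Tuple[float, List[List[float]]]:
    """Split sorted points into k contiguous groups minimizing total squared error.

    Runs in O(k * n^2) time using prefix sums for O(1) segment costs.
    """
    pts = sorted(xs)
    n = len(pts)
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= len(xs)")

    # Prefix sums of values and squared values.
    s1, s2 = [0.0], [0.0]
    for x in pts:
        s1.append(s1[-1] + x)
        s2.append(s2[-1] + x * x)

    def sse(i: int, j: int) -> float:
        """Squared error of pts[i:j] around its own mean (requires j > i)."""
        total = s1[j] - s1[i]
        return (s2[j] - s2[i]) - total * total / (j - i)

    inf = float("inf")
    # best[m][j]: minimal cost of clustering the first j points into m groups.
    best = [[inf] * (n + 1) for _ in range(k + 1)]
    back = [[0] * (n + 1) for _ in range(k + 1)]
    best[0][0] = 0.0
    for m in range(1, k + 1):
        for j in range(m, n + 1):
            for i in range(m - 1, j):  # the last group is pts[i:j]
                cand = best[m - 1][i] + sse(i, j)
                if cand < best[m][j]:
                    best[m][j] = cand
                    back[m][j] = i

    # Recover the groups by following the stored split points backwards.
    clusters: List[List[float]] = []
    j = n
    for m in range(k, 0, -1):
        i = back[m][j]
        clusters.append(pts[i:j])
        j = i
    clusters.reverse()
    return best[k][n], clusters


cost, clusters = optimal_1d_clustering([1.0, 1.2, 0.9, 5.0, 5.1, 9.0], k=3)
```

On this toy input the three well-separated groups are recovered. The same recurrence pattern, with a segment cost measuring the fit of a local-linear model instead of squared error, is the kind of structure a dynamic-programming framework for piecewise approximations would exploit.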