Chandan Singh

CL
h-index: 47
Papers: 43
Citations: 5,821
Novelty: 53%
AI Score: 62

43 Papers

CL · Jun 9, 2022
Beyond the Imitation Game: Quantifying and extrapolating the capabilities of language models

Aarohi Srivastava, Abhinav Rastogi, Abhishek Rao et al. · allen-ai, amazon-science

Language models demonstrate both quantitative improvement and new qualitative capabilities with increasing scale. Despite their potentially transformative impact, these new capabilities are as yet poorly characterized. In order to inform future research, prepare for disruptive new model capabilities, and ameliorate socially harmful effects, it is vital that we understand the present and near-future capabilities and limitations of language models. To address this challenge, we introduce the Beyond the Imitation Game benchmark (BIG-bench). BIG-bench currently consists of 204 tasks, contributed by 450 authors across 132 institutions. Task topics are diverse, drawing problems from linguistics, childhood development, math, common-sense reasoning, biology, physics, social bias, software development, and beyond. BIG-bench focuses on tasks that are believed to be beyond the capabilities of current language models. We evaluate the behavior of OpenAI's GPT models, Google-internal dense transformer architectures, and Switch-style sparse transformers on BIG-bench, across model sizes spanning millions to hundreds of billions of parameters. In addition, a team of human expert raters performed all tasks in order to provide a strong baseline. Findings include: model performance and calibration both improve with scale, but are poor in absolute terms (and when compared with rater performance); performance is remarkably similar across model classes, though with benefits from sparsity; tasks that improve gradually and predictably commonly involve a large knowledge or memorization component, whereas tasks that exhibit "breakthrough" behavior at a critical scale often involve multiple steps or components, or brittle metrics; social bias typically increases with scale in settings with ambiguous context, but this can be improved with prompting.

AI · Sep 23, 2022
Augmenting Interpretable Models with LLMs during Training

Chandan Singh, Armin Askari, Rich Caruana et al. · berkeley

Recent large language models (LLMs) have demonstrated remarkable prediction performance for a growing array of tasks. However, their proliferation into high-stakes domains (e.g. medicine) and compute-limited settings has created a burgeoning need for interpretability and efficiency. We address this need by proposing Augmented Interpretable Models (Aug-imodels), a framework for leveraging the knowledge learned by LLMs to build extremely efficient and interpretable models. Aug-imodels use LLMs during fitting but not during inference, allowing complete transparency and often a speed/memory improvement of greater than 1,000x for inference compared to LLMs. We explore two instantiations of Aug-imodels in natural-language processing: (i) Aug-GAM, which augments a generalized additive model with decoupled embeddings from an LLM and (ii) Aug-Tree, which augments a decision tree with LLM feature expansions. Across a variety of text-classification datasets, both outperform their non-augmented counterparts. Aug-GAM can even outperform much larger models (e.g. a 6-billion parameter GPT-J model), despite having 10,000x fewer parameters and being fully transparent. We further explore Aug-imodels in a natural-language fMRI study, where they generate interesting interpretations from scientific data. All code for using Aug-imodels and reproducing results is made available on Github.
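
The sketch below shows the Aug-GAM idea as described above. It assumes that each n-gram is embedded independently, that the embeddings are summed, and that a linear model is applied to the sum. The embedding table, `DIM`, and all function names are hypothetical stand-ins. A real setup would call an LLM for the embeddings and fit the weights, for example with logistic regression. The sketch shows why the fitted model stays transparent: the n-gram embeddings are decoupled, so each n-gram's effect on the output is a single fixed number.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in for frozen LLM embeddings of individual n-grams.
# A real Aug-GAM would query an LLM here; these 2-d vectors are made up.
TOY_EMBEDDINGS: Dict[str, List[float]] = {
    "great": [0.9, 0.1],
    "movie": [0.0, 0.2],
    "not": [-0.5, 0.3],
    "great movie": [0.8, 0.4],
}
DIM = 2


def ngrams(tokens: Sequence[str], max_n: int) -> List[str]:
    """All contiguous n-grams of length 1..max_n, joined by spaces."""
    return [
        " ".join(tokens[i:i + n])
        for n in range(1, max_n + 1)
        for i in range(len(tokens) - n + 1)
    ]


def embed(ngram: str) -> List[float]:
    """Embed one n-gram on its own; unseen n-grams map to the zero vector."""
    return TOY_EMBEDDINGS.get(ngram, [0.0] * DIM)


def predict(text: str, weights: List[float], bias: float, max_n: int = 2) -> float:
    """Linear model applied to the SUM of decoupled n-gram embeddings."""
    total = [0.0] * DIM
    for g in ngrams(text.split(), max_n):
        for k, e in enumerate(embed(g)):
            total[k] += e
    return bias + sum(w * t for w, t in zip(weights, total))


def ngram_contributions(text: str, weights: List[float], max_n: int = 2) -> Dict[str, float]:
    """By linearity, each n-gram adds a fixed score w . embed(ngram) to the output."""
    contribs: Dict[str, float] = {}
    for g in ngrams(text.split(), max_n):
        contribs[g] = contribs.get(g, 0.0) + sum(w * e for w, e in zip(weights, embed(g)))
    return contribs


# Usage: the prediction decomposes exactly into per-n-gram terms plus the bias.
w = [1.0, 0.5]
print(predict("great movie", w, bias=0.0))
print(ngram_contributions("great movie", w))
```

Because the model is linear in a sum, you can inspect the contribution table directly as the explanation. No separate post-hoc attribution step is needed.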

CL · Nov 3, 2023 · Code
Tell Your Model Where to Attend: Post-hoc Attention Steering for LLMs

Qingru Zhang, Chandan Singh, Liyuan Liu et al. · gatech

In human-written articles, we often leverage the subtleties of text style, such as bold and italics, to guide the attention of readers. These textual emphases are vital for the readers to grasp the conveyed information. When interacting with large language models (LLMs), we have a similar need: steering the model to pay closer attention to user-specified information, e.g., an instruction. Existing methods, however, are constrained to processing plain text and do not support such a mechanism. This motivates us to introduce PASTA (Post-hoc Attention STeering Approach), a method that allows LLMs to read text with user-specified emphasis marks. To this end, PASTA identifies a small subset of attention heads and applies precise attention reweighting on them, directing the model attention to user-specified parts. Like prompting, PASTA is applied at inference time and does not require changing any model parameters. Experiments demonstrate that PASTA can substantially enhance an LLM's ability to follow user instructions or integrate new knowledge from user inputs, leading to a significant performance improvement on a variety of tasks, e.g., an average accuracy improvement of 22% for LLaMA-7B. Our code is publicly available at https://github.com/QingruZhang/PASTA .
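
The reweighting step can be sketched for a single attention row. This follows our reading of the PASTA description: attention on non-emphasized tokens is multiplied by a small coefficient `alpha`, and the row is then renormalized. The function names and the default `alpha = 0.01` are illustrative choices, not values confirmed by the paper. The sketch also leaves out how PASTA selects which heads to steer.

```python
from typing import List, Sequence, Set


def steer_attention_row(row: Sequence[float], emphasized: Set[int], alpha: float = 0.01) -> List[float]:
    """Scale attention on non-emphasized positions by alpha, then renormalize to sum to 1.

    `alpha` is an illustrative hyperparameter (assumption), kept in (0, 1].
    """
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    scaled = [w if j in emphasized else alpha * w for j, w in enumerate(row)]
    total = sum(scaled)
    return [s / total for s in scaled]


def steer_head(attn: Sequence[Sequence[float]], emphasized: Set[int], alpha: float = 0.01) -> List[List[float]]:
    """Apply the reweighting to every query row of one (already softmaxed) attention head."""
    return [steer_attention_row(r, emphasized, alpha) for r in attn]


# Usage: a uniform row over 4 tokens, with token 0 marked as emphasized.
print(steer_attention_row([0.25, 0.25, 0.25, 0.25], {0}))
```

Each row is renormalized, so the steered head still produces a valid attention distribution. Setting `alpha = 1` leaves the row unchanged.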

CL · Sep 16, 2024 · Code
Model Tells Itself Where to Attend: Faithfulness Meets Automatic Attention Steering

Qingru Zhang, Xiaodong Yu, Chandan Singh et al. · gatech

Large language models (LLMs) have demonstrated remarkable performance across various real-world tasks. However, they often struggle to fully comprehend and effectively utilize their input contexts, resulting in responses that are unfaithful or hallucinated. This difficulty increases for contexts that are long or contain distracting information, which can divert LLMs from fully capturing essential evidence. To address this issue, many works use prompting to help LLMs utilize contextual information more faithfully. For instance, iterative prompting highlights key information in two steps: it first asks the LLM to identify important pieces of context and then derives answers accordingly. However, prompting methods are constrained to highlighting key information implicitly in token space, which is often insufficient to fully steer the model's attention. To improve model faithfulness more reliably, we propose AutoPASTA, a method that automatically identifies key contextual information and explicitly highlights it by steering an LLM's attention scores. Like prompting, AutoPASTA is applied at inference time and does not require changing any model parameters. Our experiments on open-book QA demonstrate that AutoPASTA effectively enables models to grasp essential contextual information, leading to substantially improved model faithfulness and performance, e.g., an average improvement of 7.95% for LLAMA3-70B-Instruct. Code will be publicly available at https://github.com/QingruZhang/AutoPASTA .

LG · Oct 4, 2022
Explaining Patterns in Data with Language Models via Interpretable Autoprompting

Chandan Singh, John X. Morris, Jyoti Aneja et al. · berkeley

Large language models (LLMs) have displayed an impressive ability to harness natural language to perform complex tasks. In this work, we explore whether we can leverage this learned ability to find and explain patterns in data. Specifically, given a pre-trained LLM and data examples, we introduce interpretable autoprompting (iPrompt), an algorithm that generates a natural-language string explaining the data. iPrompt iteratively alternates between generating explanations with an LLM and reranking them based on their performance when used as a prompt. Experiments on a wide range of datasets, from synthetic mathematics to natural-language understanding, show that iPrompt can yield meaningful insights by accurately finding ground-truth dataset descriptions. Moreover, the prompts produced by iPrompt are simultaneously human-interpretable and highly effective for generalization: on real-world sentiment classification datasets, iPrompt produces prompts that match or even improve upon human-written prompts for GPT-3. Finally, experiments with an fMRI dataset show the potential for iPrompt to aid in scientific discovery. All code for using the methods and data here is made available on Github.

CV · Sep 22, 2022
Learning Invariant Representations for Equivariant Neural Networks Using Orthogonal Moments

Jaspreet Singh, Chandan Singh · berkeley

The convolutional layers of standard convolutional neural networks (CNNs) are equivariant to translation. However, the convolution and fully-connected layers are not equivariant or invariant to other affine geometric transformations. Recently, a new class of CNNs has been proposed in which the conventional layers of CNNs are replaced with equivariant convolution, pooling, and batch-normalization layers. The final classification layer in equivariant neural networks is invariant to different affine geometric transformations such as rotation, reflection, and translation, and the scalar output is obtained either by eliminating the spatial dimensions of filter responses through convolution and down-sampling throughout the network or by averaging over the filter responses. In this work, we propose to integrate orthogonal moments, which give high-order statistics of the function, into the fully-connected layers as an effective means of encoding global invariance with respect to rotation, reflection, and translation. As a result, the intermediate layers of the network become equivariant while the classification layer becomes invariant. The most widely used Zernike, pseudo-Zernike, and orthogonal Fourier-Mellin moments are considered for this purpose. The effectiveness of the proposed work is evaluated by integrating the invariant transition and fully-connected layer in the architecture of group-equivariant CNNs (G-CNNs) on the rotated MNIST and CIFAR10 datasets.

CL · Oct 21, 2023
Tree Prompting: Efficient Task Adaptation without Fine-Tuning

John X. Morris, Chandan Singh, Alexander M. Rush et al. · allen-ai

Prompting language models (LMs) is the main interface for applying them to new tasks. However, for smaller LMs, prompting provides low accuracy compared to gradient-based finetuning. Tree Prompting is an approach to prompting which builds a decision tree of prompts, linking multiple LM calls together to solve a task. At inference time, each call to the LM is determined by efficiently routing the outcome of the previous call using the tree. Experiments on classification datasets show that Tree Prompting improves accuracy over competing methods and is competitive with fine-tuning. We also show that variants of Tree Prompting allow inspection of a model's decision-making process.
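
Inference-time routing can be sketched as a walk down a binary tree. Every internal node holds a prompt, and the LM's yes/no answer picks the branch. This is only an illustration. The tree here is built by hand, whereas Tree Prompting learns it from data. `toy_lm` is a hypothetical keyword matcher that stands in for a real LM call, and the prompts and labels are made up.

```python
from dataclasses import dataclass
from typing import Callable, Union


@dataclass
class PromptNode:
    """Internal node: a yes/no prompt whose answer picks the next branch."""
    prompt: str
    yes: "Tree"
    no: "Tree"


# A tree is either an internal prompt node or a leaf class label.
Tree = Union[PromptNode, str]


def route(tree: Tree, text: str, lm: Callable[[str], bool]) -> str:
    """Follow the tree from the root, making one LM call per internal node visited."""
    node = tree
    while isinstance(node, PromptNode):
        answer = lm(node.prompt + "\n" + text)
        node = node.yes if answer else node.no
    return node


# Hypothetical toy "LM": answers each question by checking for a keyword.
KEYWORDS = {"Does the text mention sports?": "game", "Is the tone positive?": "great"}


def toy_lm(query: str) -> bool:
    question, text = query.split("\n", 1)
    return KEYWORDS[question] in text.split()


tree = PromptNode(
    "Does the text mention sports?",
    yes=PromptNode("Is the tone positive?", yes="sports_positive", no="sports_negative"),
    no="other",
)
print(route(tree, "what a great game", toy_lm))
```

Only the prompts on one root-to-leaf path are ever evaluated. The recorded path of questions and answers is also what makes the decision process inspectable.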

LG · May 30, 2022
Group Probability-Weighted Tree Sums for Interpretable Modeling of Heterogeneous Data

Keyan Nasseri, Chandan Singh, James Duncan et al. · berkeley

Machine learning in high-stakes domains, such as healthcare, faces two critical challenges: (1) generalizing to diverse data distributions given limited training data while (2) maintaining interpretability. To address these challenges, we propose an instance-weighted tree-sum method that effectively pools data across diverse groups to output a concise, rule-based model. Given distinct groups of instances in a dataset (e.g., medical patients grouped by age or treatment site), our method first estimates group membership probabilities for each instance. Then, it uses these estimates as instance weights in FIGS (Tan et al. 2022), to grow a set of decision trees whose values sum to the final prediction. We call this new method Group Probability-Weighted Tree Sums (G-FIGS). G-FIGS achieves state-of-the-art prediction performance on important clinical datasets; e.g., holding the level of sensitivity fixed at 92%, G-FIGS increases specificity for identifying cervical spine injury by up to 10% over CART and up to 3% over FIGS alone, with larger gains at higher sensitivity levels. By keeping the total number of rules below 16 in FIGS, the final models remain interpretable, and we find that their rules match medical domain expertise. All code, data, and models are released on Github.

CL · Feb 3 · Code
Test-time Recursive Thinking: Self-Improvement without External Feedback

Yufan Zhuang, Chandan Singh, Liyuan Liu et al.

Modern Large Language Models (LLMs) have shown rapid improvements in reasoning capabilities, driven largely by reinforcement learning (RL) with verifiable rewards. Here, we ask whether these LLMs can self-improve without the need for additional training. We identify two core challenges for such systems: (i) efficiently generating diverse, high-quality candidate solutions, and (ii) reliably selecting correct answers in the absence of ground-truth supervision. To address these challenges, we propose Test-time Recursive Thinking (TRT), an iterative self-improvement framework that conditions generation on rollout-specific strategies, accumulated knowledge, and self-generated verification signals. Using TRT, open-source models reach 100% accuracy on AIME-25/24, and on LiveCodeBench's most difficult problems, closed-source models improve by 10.4-14.8 percentage points without external feedback.

86.6 · AI · Apr 13
Sanity Checks for Agentic Data Science

Zachary T. Rewolinski, Austin V. Zane, Hao Huang et al.

Agentic data science (ADS) pipelines have grown rapidly in both capability and adoption, with systems such as OpenAI Codex now able to directly analyze datasets and produce answers to statistical questions. However, these systems can reach falsely optimistic conclusions that are difficult for users to detect. To address this, we propose a pair of lightweight sanity checks grounded in the Predictability-Computability-Stability (PCS) framework for veridical data science. These checks use reasonable perturbations to screen whether an agent can reliably distinguish signal from noise, acting as a falsifiability constraint that can expose affirmative conclusions as unsupported. Together, the two checks characterize the trustworthiness of an ADS output, e.g. whether it has found stable signal, is responding to noise, or is sensitive to incidental aspects of the input. We validate the approach on synthetic data with controlled signal-to-noise ratios, confirming that the sanity checks track ground-truth signal strength. We then demonstrate the checks on 11 real-world datasets using OpenAI Codex, characterizing the trustworthiness of each conclusion and finding that in 6 of the datasets an affirmative conclusion is not well-supported, even though a single ADS run may support one. We further analyze failure modes of ADS systems and find that ADS self-reported confidence is poorly calibrated to the empirical stability of its conclusions.

AI · Jan 14
Human-AI Co-design for Clinical Prediction Models

Jean Feng, Avni Kothari, Patrick Vossler et al.

Developing safe, effective, and practically useful clinical prediction models (CPMs) traditionally requires iterative collaboration between clinical experts, data scientists, and informaticists. This process refines the often small but critical details of the model building process, such as which features/patients to include and how clinical categories should be defined. However, this traditional collaboration process is extremely time- and resource-intensive, resulting in only a small fraction of CPMs reaching clinical practice. This challenge intensifies when teams attempt to incorporate unstructured clinical notes, which can contain an enormous number of concepts. To address this challenge, we introduce HACHI, an iterative human-in-the-loop framework that uses AI agents to accelerate the development of fully interpretable CPMs by enabling the exploration of concepts in clinical notes. HACHI alternates between (i) an AI agent rapidly exploring and evaluating candidate concepts in clinical notes and (ii) clinical and domain experts providing feedback to improve the CPM learning process. HACHI defines concepts as simple yes-no questions that are used in linear models, allowing the clinical AI team to transparently review, refine, and validate the CPM learned in each round. In two real-world prediction tasks (acute kidney injury and traumatic brain injury), HACHI outperforms existing approaches, surfaces new clinically relevant concepts not included in commonly-used CPMs, and improves model generalizability across clinical sites and time periods. Furthermore, HACHI reveals the critical role of the clinical AI team, such as directing the AI agent to explore concepts that it had not previously considered, adjusting the granularity of concepts it considers, changing the objective function to better align with the clinical objectives, and identifying issues of data bias and leakage.

LG · Feb 26
Interpreting and Steering State-Space Models via Activation Subspace Bottlenecks

Vamshi Sunku Mohan, Kaustubh Gupta, Aneesha Das et al.

State-space models (SSMs) have emerged as an efficient strategy for building powerful language models, avoiding the quadratic complexity of computing attention in transformers. Despite their promise, the interpretability and steerability of modern SSMs remain relatively underexplored. We take a major step in this direction by identifying activation subspace bottlenecks in the Mamba family of SSM models using tools from mechanistic interpretability. We then introduce a test-time steering intervention that simply multiplies the activations of the identified bottlenecks by a scalar. Across 5 SSMs and 6 diverse benchmarks, this intervention improves performance by an average of 8.27%, without requiring any task-specific tuning. Finally, we validate that the identified bottlenecks are indeed hindering performance by modifying them to yield an architecture we call Stable-Mamba, which achieves long-context performance gains when retrained from scratch.

CL · Mar 1, 2024 · Code
Attribute Structuring Improves LLM-Based Evaluation of Clinical Text Summaries

Zelalem Gero, Chandan Singh, Yiqing Xie et al. · microsoft-research

Summarizing clinical text is crucial in health decision-support and clinical research. Large language models (LLMs) have shown the potential to generate accurate clinical text summaries, but still struggle with issues regarding grounding and evaluation, especially in safety-critical domains such as health. Holistically evaluating text summaries is challenging because they may contain unsubstantiated information. Here, we explore a general mitigation framework using Attribute Structuring (AS), which structures the summary evaluation process. It decomposes the evaluation process into a grounded procedure that uses an LLM for relatively simple structuring and scoring tasks, rather than the full task of holistic summary evaluation. Experiments show that AS consistently improves the correspondence between human annotations and automated metrics in clinical text summarization. Additionally, AS yields interpretations in the form of a short text span corresponding to each output, which enables efficient human auditing, paving the way towards trustworthy evaluation of clinical information in resource-constrained scenarios. We release our code, prompts, and an open-source benchmark at https://github.com/microsoft/attribute-structuring.

CL · May 29, 2025 · Code
OMNIGUARD: An Efficient Approach for AI Safety Moderation Across Modalities

Sahil Verma, Keegan Hines, Jeff Bilmes et al.

The emerging capabilities of large language models (LLMs) have sparked concerns about their immediate potential for harmful misuse. The core approach to mitigate these concerns is the detection of harmful queries to the model. Current detection approaches are fallible, and are particularly susceptible to attacks that exploit mismatched generalization of model capabilities (e.g., prompts in low-resource languages or prompts provided in non-text modalities such as image and audio). To tackle this challenge, we propose OMNIGUARD, an approach for detecting harmful prompts across languages and modalities. Our approach (i) identifies internal representations of an LLM/MLLM that are aligned across languages or modalities and then (ii) uses them to build a language-agnostic or modality-agnostic classifier for detecting harmful prompts. OMNIGUARD improves harmful prompt classification accuracy by 11.57% over the strongest baseline in a multilingual setting, by 20.44% for image-based prompts, and sets a new SOTA for audio-based prompts. By repurposing embeddings computed during generation, OMNIGUARD is also very efficient (≈120× faster than the next fastest baseline). Code and data are available at: https://github.com/vsahil/OmniGuard.

CL · Jan 16
Do explanations generalize across large reasoning models?

Koyena Pal, David Bau, Chandan Singh

Large reasoning models (LRMs) produce a textual chain of thought (CoT) in the process of solving a problem, which serves as a potentially powerful tool to understand the problem by surfacing a human-readable, natural-language explanation. However, it is unclear whether these explanations generalize, i.e. whether they capture general patterns about the underlying problem rather than patterns which are esoteric to the LRM. This is a crucial question in understanding or discovering new concepts, e.g. in AI for science. We study this generalization question by evaluating a specific notion of generalizability: whether explanations produced by one LRM induce the same behavior when given to other LRMs. We find that CoT explanations often exhibit this form of generalization (i.e. they increase consistency between LRMs) and that this increased generalization is correlated with human preference rankings and post-training with reinforcement learning. We further analyze the conditions under which explanations yield consistent answers and propose a straightforward, sentence-level ensembling strategy that improves consistency. Taken together, these results prescribe caution when using LRM explanations to yield new insights and outline a framework for characterizing LRM explanation generalization.

65.1 · LG · Apr 14
Selecting Feature Interactions for Generalized Additive Models by Distilling Foundation Models

Jingyun Jia, Chandan Singh, Rich Caruana et al.

Identifying meaningful feature interactions is a central challenge in building accurate and interpretable models for tabular data. Generalized additive models (GAMs) have shown great success at modeling tabular data, but often rely on heuristic procedures to select interactions, potentially missing higher-order or context-dependent effects. To meet this challenge, we propose TabDistill, a method that leverages tabular foundation models and post-hoc distillation methods. Our key intuition is that tabular foundation models implicitly learn rich, adaptive feature dependencies through large-scale representation learning. Given a dataset, TabDistill first fits a tabular foundation model to the dataset, and then applies a post-hoc interaction attribution method to extract salient feature interactions from it. We evaluate these interactions by then using them as terms in a GAM. Across tasks, we find that interactions identified by TabDistill lead to consistent improvements in downstream GAMs' predictive performance. Our results suggest that tabular foundation models can serve as effective, data-driven guides for interaction discovery, bridging high-capacity models and interpretable additive frameworks.

GR · Mar 13, 2025 · Code
Towards Understanding Graphical Perception in Large Multimodal Models

Kai Zhang, Jianwei Yang, Jeevana Priya Inala et al. · microsoft-research

Despite the promising results of large multimodal models (LMMs) in complex vision-language tasks that require knowledge, reasoning, and perception abilities together, we surprisingly found that these models struggle with simple tasks on infographics that require perception only. As existing benchmarks primarily focus on end tasks that require various abilities, they provide limited fine-grained insight into the limitations of the models' perception abilities. To address this gap, we leverage the theory of graphical perception, an approach used to study how humans decode visual information encoded on charts and graphs, to develop an evaluation framework for analyzing gaps in LMMs' perception abilities in charts. With automated task generation and response evaluation designs, our framework enables comprehensive and controlled testing of LMMs' graphical perception across diverse chart types, visual elements, and task types. We apply our framework to evaluate and diagnose the perception capabilities of state-of-the-art LMMs at three granularity levels (chart, visual element, and pixel). Our findings underscore several critical limitations of current state-of-the-art LMMs, including GPT-4o: their inability to (1) generalize across chart types, (2) understand fundamental visual elements, and (3) cross-reference values within a chart. These insights provide guidance for future improvements in perception abilities of LMMs. The evaluation framework and labeled data are publicly available at https://github.com/microsoft/lmm-graphical-perception.

96.5 · LG · May 14
Test-Time Learning with an Evolving Library

Weijia Xu, Alessandro Sordoni, Chandan Singh et al.

We introduce EvoLib, a test-time learning framework that enables large language models to accumulate, reuse, and evolve knowledge across problem instances without parameter updates or external supervision. Instead of adapting model parameters, our approach maintains a shared library of knowledge abstractions, including modular skills and reflective insights, automatically extracted from the model's own inference trajectories. To support continual improvement, we introduce a principled weighting and consolidation mechanism that jointly optimizes for immediate utility and long-term value. This allows simple, instance-specific abstractions to evolve into more general and reusable ones over time. Across challenging benchmarks in mathematical reasoning, code generation, and multi-turn agentic environments, EvoLib improves substantially over the top test-time scaling and learning methods without ground-truth feedback.

CL · Oct 31, 2024 · Code
Interpretable Next-token Prediction via the Generalized Induction Head

Eunji Kim, Sriya Mantena, Weiwei Yang et al.

While large transformer models excel in predictive performance, their lack of interpretability restricts their usefulness in high-stakes domains. To remedy this, we propose the Generalized Induction-Head Model (GIM), an interpretable model for next-token prediction inspired by the observation of "induction heads" in LLMs. GIM is a retrieval-based module that identifies similar sequences in the input context by combining exact n-gram matching and fuzzy matching based on a neural similarity metric. We evaluate GIM in two settings: language modeling and fMRI response prediction. In language modeling, GIM improves next-token prediction by up to 25 percentage points over interpretable baselines, significantly narrowing the gap with black-box LLMs. In an fMRI setting, GIM improves neural response prediction by 20% and offers insights into the language selectivity of the brain. GIM represents a significant step toward uniting interpretability and performance across domains. The code is available at https://github.com/ejkim47/generalized-induction-head.
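
The exact n-gram matching half of GIM follows the induction-head pattern: find earlier places in the context where the current suffix occurred, then predict the token that came next. The sketch below is a simplified illustration under that reading. It uses a hypothetical `induction_predict` helper that backs off from the longest suffix to shorter ones and takes a majority vote among matches. GIM's fuzzy matching with a neural similarity metric is omitted.

```python
from collections import Counter
from typing import Optional, Sequence


def induction_predict(context: Sequence[str], max_n: int = 3) -> Optional[str]:
    """Exact-match induction (hypothetical helper): find earlier occurrences of the
    longest context suffix (up to max_n tokens) and vote on the token that followed."""
    # Try long suffixes first and back off to shorter ones.
    for n in range(min(max_n, len(context) - 1), 0, -1):
        suffix = list(context[-n:])
        votes: Counter = Counter()
        # Only windows that are followed by at least one token can vote,
        # which also excludes the suffix itself.
        for i in range(len(context) - n):
            if list(context[i:i + n]) == suffix:
                votes[context[i + n]] += 1
        if votes:
            return votes.most_common(1)[0][0]
    return None  # no earlier match at any length


# Usage: "a b" appeared before and was followed by "c".
print(induction_predict("a b c a b".split()))
```

Each prediction points back to specific earlier spans of the context. That traceability is the kind of interpretability a retrieval-based module like this offers.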

CL · Jan 25, 2024 · Code
Towards Consistent Natural-Language Explanations via Explanation-Consistency Finetuning

Yanda Chen, Chandan Singh, Xiaodong Liu et al.

Large language models (LLMs) often generate convincing, fluent explanations. However, unlike humans, they often generate inconsistent explanations on different inputs. For example, an LLM may generate the explanation "all birds can fly" when answering the question "Can sparrows fly?" but meanwhile answer "no" to the related question "Can penguins fly?". Explanations should be consistent across related examples so that they allow a human to simulate the LLM's decision process on multiple examples. We propose explanation-consistency finetuning (EC-finetuning), a method that adapts LLMs to generate more consistent natural-language explanations on related examples. EC-finetuning involves finetuning LLMs on synthetic data that is carefully constructed to contain consistent explanations. Across a variety of question-answering datasets in various domains, EC-finetuning yields a 10.0% relative explanation consistency improvement on four finetuning datasets, and generalizes to seven out-of-distribution datasets not seen during finetuning (+4.5% relative). Code is available at https://github.com/yandachen/explanation-consistency-finetuning .

LG · Feb 2, 2022 · Code
Hierarchical Shrinkage: improving the accuracy and interpretability of tree-based methods

Abhineet Agarwal, Yan Shuo Tan, Omer Ronen et al.

Tree-based models such as decision trees and random forests (RF) are a cornerstone of modern machine-learning practice. To mitigate overfitting, trees are typically regularized by a variety of techniques that modify their structure (e.g. pruning). We introduce Hierarchical Shrinkage (HS), a post-hoc algorithm that does not modify the tree structure, and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors. The amount of shrinkage is controlled by a single regularization parameter and the number of data points in each ancestor. Since HS is a post-hoc method, it is extremely fast, compatible with any tree growing algorithm, and can be used synergistically with other regularization techniques. Extensive experiments over a wide variety of real-world datasets show that HS substantially increases the predictive performance of decision trees, even when used in conjunction with other regularization techniques. Moreover, we find that applying HS to each tree in an RF often improves accuracy, as well as its interpretability by simplifying and stabilizing its decision boundaries and SHAP values. We further explain the success of HS in improving prediction performance by showing its equivalence to ridge regression on a (supervised) basis constructed of decision stumps associated with the internal nodes of a tree. All code and models are released in a full-fledged package available on Github (github.com/csinva/imodels).
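
As we understand the HS formulation, Hierarchical Shrinkage can be written as a telescoping sum over the root-to-leaf path t_0, ..., t_L:

f(x) = mean(t_0) + sum over l = 1..L of [mean(t_l) - mean(t_{l-1})] / (1 + lam / N(t_{l-1}))

Here N(t) is the number of training samples in node t and lam is the single regularization parameter. The hypothetical `hs_prediction` helper below evaluates that expression for one leaf. It is a sketch for intuition, not the imodels implementation. With lam = 0 it recovers the ordinary leaf mean, and as lam grows the prediction shrinks toward the root mean.

```python
from typing import Sequence, Tuple


def hs_prediction(path: Sequence[Tuple[float, int]], lam: float) -> float:
    """Hierarchical-shrinkage prediction for one leaf (illustrative helper).

    path: (node_mean, node_sample_count) pairs from the root t_0 down to the leaf t_L.
    Each parent-to-child change in mean is divided by 1 + lam / N(parent), so steps
    taken from small parent nodes are shrunk the most.
    """
    if not path:
        raise ValueError("path must contain at least the root")
    pred = path[0][0]  # start from the root mean
    for (parent_mean, parent_n), (child_mean, _) in zip(path, path[1:]):
        pred += (child_mean - parent_mean) / (1.0 + lam / parent_n)
    return pred


# Usage: root (mean 10, 100 samples) -> node (mean 20, 50 samples) -> leaf (mean 30).
path = [(10.0, 100), (20.0, 50), (30.0, 10)]
print(hs_prediction(path, lam=0.0))   # plain leaf mean
print(hs_prediction(path, lam=50.0))  # shrunk toward the ancestors' means
```

Only node means and counts are needed, so shrinkage can be applied after a tree is grown by any algorithm. This matches the post-hoc property described above.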

CL · Dec 6, 2021 · Code
NL-Augmenter: A Framework for Task-Sensitive Natural Language Augmentation

Kaustubh D. Dhole, Varun Gangal, Sebastian Gehrmann et al.

Data augmentation is an important component in the robustness evaluation of models in natural language processing (NLP) and in enhancing the diversity of the data they are trained on. In this paper, we present NL-Augmenter, a new participatory Python-based natural language augmentation framework which supports the creation of both transformations (modifications to the data) and filters (data splits according to specific features). We describe the framework and an initial set of 117 transformations and 23 filters for a variety of natural language tasks. We demonstrate the efficacy of NL-Augmenter by using several of its transformations to analyze the robustness of popular natural language models. The infrastructure, datacards and robustness analysis results are available publicly on the NL-Augmenter repository (https://github.com/GEM-benchmark/NL-Augmenter).

ML · Jul 19, 2021 · Code
Adaptive wavelet distillation from neural networks through interpretations

Wooseok Ha, Chandan Singh, Francois Lanusse et al.

Recent deep-learning models have achieved impressive prediction performance, but often sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is the goal itself. Moreover, interpretable models are concise and often yield computational efficiency. Here, we propose adaptive wavelet distillation (AWD), a method which aims to distill information from a trained neural network into a wavelet transform. Specifically, AWD penalizes feature attributions of a neural network in the wavelet domain to learn an effective multi-resolution wavelet transform. The resulting model is highly predictive, concise, computationally efficient, and has properties (such as a multi-scale structure) which make it easy to interpret. In close collaboration with domain experts, we showcase how AWD addresses challenges in two real-world settings: cosmological parameter inference and molecular-partner prediction. In both cases, AWD yields a scientifically interpretable and concise model which gives predictive performance better than state-of-the-art neural networks. Moreover, AWD identifies predictive features that are scientifically meaningful in the context of respective domains. All code and models are released in a full-fledged package available on Github (https://github.com/Yu-Group/adaptive-wavelets).

CV · Mar 24, 2021 · Code
Matched sample selection with GANs for mitigating attribute confounding

Chandan Singh, Guha Balakrishnan, Pietro Perona

Measuring biases of vision systems with respect to protected attributes like gender and age is critical as these systems gain widespread use in society. However, significant correlations between attributes in benchmark datasets make it difficult to separate algorithmic bias from dataset bias. To mitigate such attribute confounding during bias analysis, we propose a matching approach that selects a subset of images from the full dataset with balanced attribute distributions across protected attributes. Our matching approach first projects real images onto the latent space of a generative adversarial network (GAN) in a manner that preserves semantic attributes. It then finds image matches in this latent space across a chosen protected attribute, yielding a dataset where semantic and perceptual attributes are balanced across the protected attribute. We validate projection and matching strategies with qualitative, quantitative, and human annotation experiments. We demonstrate our work in the context of gender bias in multiple open-source facial-recognition classifiers and find that bias persists after removing key confounders via matching. Code and documentation for reproducing these results and applying the methods to new data are available at https://github.com/csinva/matching-with-gans .
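
The matching step can be sketched as a greedy nearest-neighbour search in latent space. This is a simplified stand-in under stated assumptions: the latent codes are already given (the GAN projection is not shown), the protected attribute is binary, and each image is matched at most once. The function name is hypothetical, and the paper's exact matching procedure may differ.

```python
import math
from typing import List, Sequence, Tuple


def match_across_attribute(
    latents: Sequence[Sequence[float]], attribute: Sequence[int]
) -> List[Tuple[int, int]]:
    """Greedily pair each attribute-0 image with its nearest unused attribute-1 image."""
    group_a = [i for i, a in enumerate(attribute) if a == 0]
    available = {i for i, a in enumerate(attribute) if a == 1}
    pairs: List[Tuple[int, int]] = []
    for i in group_a:
        if not available:
            break  # unmatched images are left out of the balanced subset
        j = min(available, key=lambda k: math.dist(latents[i], latents[k]))
        available.remove(j)  # match without replacement
        pairs.append((i, j))
    return pairs


latents = [[0.0, 0.0], [10.0, 10.0], [0.1, 0.0], [10.0, 10.1]]
attribute = [0, 0, 1, 1]
print(match_across_attribute(latents, attribute))  # [(0, 2), (1, 3)]
```

Greedy matching is order-dependent; an optimal assignment (for example, the Hungarian algorithm) is a common alternative when the pool is small.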

CL · Jan 30, 2024
Rethinking Interpretability in the Era of Large Language Models

Chandan Singh, Jeevana Priya Inala, Michel Galley et al.

Interpretable machine learning has exploded as an area of interest over the last decade, sparked by the rise of increasingly large datasets and deep neural networks. Simultaneously, large language models (LLMs) have demonstrated remarkable capabilities across a wide array of tasks, offering a chance to rethink opportunities in interpretable machine learning. Notably, the capability to explain in natural language allows LLMs to expand the scale and complexity of patterns that can be given to a human. However, these new capabilities raise new challenges, such as hallucinated explanations and immense computational costs. In this position paper, we start by reviewing existing methods to evaluate the emerging field of LLM interpretation (both interpreting LLMs and using LLMs for explanation). We contend that, despite their limitations, LLMs hold the opportunity to redefine interpretability with a more ambitious scope across many applications, including in auditing LLMs themselves. We highlight two emerging research priorities for LLM interpretation: using LLMs to directly analyze new datasets and to generate interactive explanations.

83.7 · AI · May 5
Agentic-imodels: Evolving agentic interpretability tools via autoresearch

Chandan Singh, Yan Shuo Tan, Weijia Xu et al.

Agentic data science (ADS) systems are rapidly improving their capability to autonomously analyze, fit, and interpret data, potentially moving towards a future where agents conduct the vast majority of data-science work. However, current ADS systems use statistical tools designed to be interpretable by humans, rather than interpretable by agents. To address this, we introduce Agentic-imodels, an agentic autoresearch loop that evolves data-science tools designed to be interpretable by agents. Specifically, it develops a library of scikit-learn-compatible regressors for tabular data that are optimized for both predictive performance and a novel LLM-based interpretability metric. The metric aggregates a suite of LLM-graded tests that probe whether a fitted model's string representation is "simulatable" by an LLM, i.e. whether the LLM can answer questions about the model's behavior by reading its string output alone. We find that the evolved models jointly improve predictive performance and agent-facing interpretability, generalizing to new datasets and new interpretability tests. Furthermore, these evolved models improve downstream end-to-end ADS, increasing performance for Copilot CLI, Claude Code, and Codex on the BLADE benchmark by up to 73%.
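
The idea of a regressor whose string form is "simulatable" can be illustrated with a deliberately tiny model. The sketch below is hypothetical and is not one of the evolved Agentic-imodels regressors: it follows the scikit-learn `fit`/`predict` convention in plain Python and fits a single-feature linear rule, so that `str(model)` is a formula an LLM could evaluate by hand.

```python
from typing import List, Optional, Sequence, Tuple


class OneFeatureLinearRegressor:
    """Hypothetical agent-readable regressor: y = coef * x_j + intercept for one feature j."""

    def fit(self, X: Sequence[Sequence[float]], y: Sequence[float]) -> "OneFeatureLinearRegressor":
        n = len(y)
        y_mean = sum(y) / n
        best: Optional[Tuple[float, int, float, float]] = None
        for j in range(len(X[0])):
            xs = [row[j] for row in X]
            x_mean = sum(xs) / n
            sxx = sum((x - x_mean) ** 2 for x in xs)
            if sxx == 0:
                continue  # a constant feature carries no signal
            # Ordinary least squares on feature j alone.
            slope = sum((x - x_mean) * (t - y_mean) for x, t in zip(xs, y)) / sxx
            intercept = y_mean - slope * x_mean
            sse = sum((t - (slope * x + intercept)) ** 2 for x, t in zip(xs, y))
            if best is None or sse < best[0]:
                best = (sse, j, slope, intercept)
        if best is None:
            raise ValueError("all features are constant")
        _, self.feature_, self.coef_, self.intercept_ = best
        return self

    def predict(self, X: Sequence[Sequence[float]]) -> List[float]:
        return [self.coef_ * row[self.feature_] + self.intercept_ for row in X]

    def __str__(self) -> str:
        # The entire model is one formula that can be evaluated by reading it.
        return f"y = {self.coef_:.3g} * x{self.feature_} + {self.intercept_:.3g}"


X = [[1, 0], [2, 5], [3, 1], [4, 2]]
y = [3, 5, 7, 9]  # exactly 2 * x0 + 1
model = OneFeatureLinearRegressor().fit(X, y)
print(model)                      # y = 2 * x0 + 1
print(model.predict([[10, 0]]))   # [21.0]
```

A simulatability test in the spirit of the metric would show an LLM only `str(model)` and ask it to predict the output for a new input.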

CV · Feb 14, 2025
Simplifying DINO via Coding Rate Regularization

Ziyang Wu, Jingyuan Zhang, Druv Pai et al.

DINO and DINOv2 are two widely used model families for learning representations from unlabeled image data at large scale. Their learned representations often enable state-of-the-art performance on downstream tasks, such as image classification and segmentation. However, they employ many empirically motivated design choices, and their training pipelines are highly complex and unstable: many hyperparameters need to be carefully tuned to ensure that the representations do not collapse, which makes the models considerably difficult to improve or adapt to new domains. In this work, we posit that we can remove most of these empirically motivated idiosyncrasies from the pre-training pipelines and need only add an explicit coding rate term to the loss function to avoid collapse of the representations. As a result, we obtain highly simplified variants of DINO and DINOv2, which we call SimDINO and SimDINOv2, respectively. Remarkably, these simplified models are more robust to different design choices, such as network architecture and hyperparameters, and they learn even higher-quality representations, measured by performance on downstream tasks, offering a Pareto improvement over the corresponding DINO and DINOv2 models. This work highlights the potential of using simplifying design principles to improve the empirical practice of deep learning.
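
The coding rate term can be made concrete. The sketch below uses the coding-rate formula from the maximal-coding-rate-reduction line of work, $R_\varepsilon(Z) = \tfrac{1}{2}\log\det\big(I + \tfrac{d}{n\varepsilon^2} Z^\top Z\big)$ for $n$ embeddings of dimension $d$ stacked as the rows of $Z$. SimDINO's exact normalization, and how it weights this term in the loss, may differ; the value of `eps` is an arbitrary choice. The point it illustrates is that collapsed embeddings have a lower rate than spread-out ones, so subtracting a multiple of $R_\varepsilon$ from the loss discourages collapse.

```python
import math
from typing import List, Sequence


def coding_rate(Z: Sequence[Sequence[float]], eps: float = 0.5) -> float:
    """R_eps(Z) = 1/2 * logdet(I + d / (n * eps^2) * Z^T Z) for n row-embeddings of dim d."""
    n, d = len(Z), len(Z[0])
    c = d / (n * eps ** 2)
    # M = I_d + c * sum_i z_i z_i^T is symmetric positive definite.
    M = [[(1.0 if a == b else 0.0) + c * sum(z[a] * z[b] for z in Z) for b in range(d)]
         for a in range(d)]
    # Cholesky factor M = L L^T, so 1/2 * logdet(M) = sum_i log L_ii.
    L: List[List[float]] = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i + 1):
            s = M[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return sum(math.log(L[i][i]) for i in range(d))


collapsed = [[1.0, 0.0], [1.0, 0.0]]  # both embeddings identical
spread = [[1.0, 0.0], [0.0, 1.0]]     # embeddings span both dimensions
print(coding_rate(collapsed) < coding_rate(spread))  # True
```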

CL · May 20, 2025
Text Generation Beyond Discrete Token Sampling

Yufan Zhuang, Liyuan Liu, Chandan Singh et al.

In standard autoregressive generation, an LLM predicts the next-token distribution, samples a discrete token, and then discards the distribution, passing only the sampled token as new input. To preserve this distribution's rich information, we propose Mixture of Inputs (MoI), a training-free method for autoregressive generation. After generating a token following the standard paradigm, we construct a new input that blends the generated discrete token with the previously discarded token distribution. Specifically, we employ a Bayesian estimation method that treats the token distribution as the prior, the sampled token as the observation, and replaces the conventional one-hot vector with the continuous posterior expectation as the new model input. MoI allows the model to maintain a richer internal representation throughout the generation process, resulting in improved text quality and reasoning capabilities. On mathematical reasoning, code generation, and PhD-level QA tasks, MoI consistently improves performance across multiple models including QwQ-32B, Nemotron-Super-49B, Gemma-3-27B, and DAPO-Qwen-32B, with no additional training and negligible computational overhead.
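
The blending step can be sketched with a Dirichlet-style posterior mean. Treat the next-token distribution $p$, scaled by a concentration $\beta$, as the prior and the sampled token $y$ as a single observation. The posterior mean $w = (\beta p + e_y)/(\beta + 1)$ then gives mixing weights over the embedding table. The concentration `beta` and this exact parameterization are assumptions for illustration; MoI's own estimator may weight the two terms differently. With `beta = 0` the input reduces to the ordinary one-hot embedding.

```python
from typing import List, Sequence


def moi_weights(probs: Sequence[float], sampled: int, beta: float) -> List[float]:
    """Posterior-mean weights: prior beta * probs plus one observed count for `sampled`."""
    return [
        (beta * p + (1.0 if i == sampled else 0.0)) / (beta + 1.0)
        for i, p in enumerate(probs)
    ]


def blended_input(weights: Sequence[float], embeddings: Sequence[Sequence[float]]) -> List[float]:
    """Replace the one-hot embedding lookup with a weighted average of token embeddings."""
    dim = len(embeddings[0])
    return [sum(w * e[k] for w, e in zip(weights, embeddings)) for k in range(dim)]


probs = [0.5, 0.5]                     # next-token distribution over a toy 2-token vocab
embeddings = [[1.0, 0.0], [0.0, 1.0]]  # toy embedding table
w = moi_weights(probs, sampled=0, beta=1.0)
print(w)                               # [0.75, 0.25]
print(blended_input(w, embeddings))    # [0.75, 0.25]
```

Because the weights sum to one, the blended vector stays in the convex hull of the token embeddings, which is one reason a training-free method of this kind can remain close to the model's usual input distribution.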

LG · Oct 21, 2024
Bayesian Concept Bottleneck Models with LLM Priors

Jean Feng, Avni Kothari, Luke Zier et al.

Concept Bottleneck Models (CBMs) have been proposed as a compromise between white-box and black-box models, aiming to achieve interpretability without sacrificing accuracy. The standard training procedure for CBMs is to predefine a candidate set of human-interpretable concepts, extract their values from the training data, and identify a sparse subset as inputs to a transparent prediction model. However, such approaches are often hampered by the tradeoff between exploring a sufficiently large set of concepts versus controlling the cost of obtaining concept extractions, resulting in a large interpretability-accuracy tradeoff. This work investigates a novel approach that sidesteps these challenges: BC-LLM iteratively searches over a potentially infinite set of concepts within a Bayesian framework, in which Large Language Models (LLMs) serve as both a concept extraction mechanism and prior. Even though LLMs can be miscalibrated and hallucinate, we prove that BC-LLM can provide rigorous statistical inference and uncertainty quantification. Across image, text, and tabular datasets, BC-LLM outperforms interpretable baselines and even black-box models in certain settings, converges more rapidly towards relevant concepts, and is more robust to out-of-distribution samples.

LG · Feb 6, 2024
Learning a Decision Tree Algorithm with Transformers

Yufan Zhuang, Liyuan Liu, Chandan Singh et al.

Decision trees are renowned for their ability to achieve high predictive performance while remaining interpretable, especially on tabular data. Traditionally, they are constructed through recursive algorithms, where they partition the data at every node in a tree. However, identifying a good partition is challenging, as decision trees optimized for local segments may not yield global generalization. To address this, we introduce MetaTree, a transformer-based model trained via meta-learning to directly produce strong decision trees. Specifically, we fit both greedy decision trees and globally optimized decision trees on a large number of datasets, and train MetaTree to produce only the trees that achieve strong generalization performance. This training enables MetaTree to emulate these algorithms and intelligently adapt its strategy according to the context, thereby achieving superior generalization performance.
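
For contrast with MetaTree's learned approach, the sketch below shows the local, greedy step that traditional algorithms such as CART repeat at every node: scan each feature and threshold, and keep the split with the lowest weighted Gini impurity. It is a plain-Python illustration of that baseline, not MetaTree, which instead trains a transformer to emit whole trees.

```python
from typing import Dict, Optional, Sequence, Tuple


def gini(labels: Sequence[int]) -> float:
    """Gini impurity of a set of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts: Dict[int, int] = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_split(X: Sequence[Sequence[float]], y: Sequence[int]) -> Optional[Tuple[int, float]]:
    """Return (feature, threshold) minimizing weighted child impurity, or None."""
    best: Optional[Tuple[float, int, float]] = None
    n = len(y)
    for j in range(len(X[0])):
        for t in sorted({row[j] for row in X}):
            left = [label for row, label in zip(X, y) if row[j] <= t]
            right = [label for row, label in zip(X, y) if row[j] > t]
            if not left or not right:
                continue  # a split must send samples both ways
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if best is None or score < best[0]:
                best = (score, j, t)
    return None if best is None else (best[1], best[2])


X = [[1, 5], [2, 3], [3, 8], [4, 1]]
y = [0, 0, 1, 1]
print(best_split(X, y))  # (0, 2): "x0 <= 2" separates the classes perfectly
```

A greedy builder recurses on the two children; each choice is optimal only for the current node, which is the limitation the abstract points to.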

CL · May 30, 2023
Self-Verification Improves Few-Shot Clinical Information Extraction

Zelalem Gero, Chandan Singh, Hao Cheng et al.

Extracting patient information from unstructured text is a critical task in health decision-support and clinical research. Large language models (LLMs) have shown the potential to accelerate clinical curation via few-shot in-context learning, in contrast to supervised learning, which requires much more costly human annotations. However, despite drastic advances in modern LLMs such as GPT-4, they still struggle with issues regarding accuracy and interpretability, especially in mission-critical domains such as health. Here, we explore a general mitigation framework using self-verification, which leverages the LLM to provide provenance for its own extraction and check its own outputs. This is made possible by the asymmetry between verification and generation, where verification is often much easier than generation. Experimental results show that our method consistently improves accuracy for various LLMs in standard clinical information extraction tasks. Additionally, self-verification yields interpretations in the form of a short text span corresponding to each output, which makes it very efficient for human experts to audit the results, paving the way towards trustworthy extraction of clinical information in resource-constrained scenarios. To facilitate future research in this direction, we release our code and prompts.
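
One piece of the self-verification loop can be made deterministic: each extraction comes with the text span the LLM cites as provenance, and an extraction whose span does not actually occur in the note can be flagged. In the paper the verification is performed by the LLM itself; the exact-match check, the `Extraction` record, and the function names below are hypothetical simplifications for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Extraction:
    value: str
    evidence: str  # text span the LLM cites as provenance for `value`


def _normalize(text: str) -> str:
    # Lowercase and collapse whitespace so line breaks do not block a match.
    return " ".join(text.lower().split())


def verify_provenance(note: str, extractions: List[Extraction]) -> List[Extraction]:
    """Keep only extractions whose cited evidence occurs verbatim in the note."""
    text = _normalize(note)
    return [ex for ex in extractions if ex.evidence.strip() and _normalize(ex.evidence) in text]


note = "Patient takes Metformin 500 mg daily.\nDenies smoking."
candidates = [
    Extraction("metformin", "takes metformin 500 mg"),
    Extraction("aspirin", "started aspirin"),  # unsupported: span is not in the note
]
print([ex.value for ex in verify_provenance(note, candidates)])  # ['metformin']
```

The surviving evidence spans double as the short, auditable interpretations the abstract describes.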

AI · May 17, 2023
Explaining black box text modules in natural language with language models

Chandan Singh, Aliyah R. Hsu, Richard Antonello et al.

Large language models (LLMs) have demonstrated remarkable prediction performance for a growing array of tasks. However, their rapid proliferation and increasing opacity have created a growing need for interpretability. Here, we ask whether we can automatically obtain natural language explanations for black box text modules. A "text module" is any function that maps text to a scalar continuous value, such as a submodule within an LLM or a fitted model of a brain region. "Black box" indicates that we only have access to the module's inputs/outputs. We introduce Summarize and Score (SASC), a method that takes in a text module and returns a natural language explanation of the module's selectivity along with a score for how reliable the explanation is. We study SASC in three contexts. First, we evaluate SASC on synthetic modules and find that it often recovers ground truth explanations. Second, we use SASC to explain modules found within a pre-trained BERT model, enabling inspection of the model's internals. Finally, we show that SASC can generate explanations for the response of individual fMRI voxels to language stimuli, with potential applications to fine-grained brain mapping. All code for using SASC and reproducing results is made available on Github.
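
The "score" half of SASC can be sketched as a difference of mean module responses: if an explanation is good, text related to it should drive the module higher than unrelated text. In SASC the related text is generated by an LLM from the explanation; here it is written by hand, and the toy `food_module` and the function names are hypothetical.

```python
from statistics import mean
from typing import Callable, Sequence

FOOD_WORDS = {"pizza", "bread", "soup"}


def food_module(text: str) -> float:
    """Toy black-box text module: counts food words (stand-in for an LLM unit or fMRI voxel)."""
    return float(sum(word in FOOD_WORDS for word in text.lower().split()))


def explanation_score(
    module: Callable[[str], float], related: Sequence[str], unrelated: Sequence[str]
) -> float:
    """Mean response on explanation-related text minus mean response on unrelated text."""
    return mean(module(t) for t in related) - mean(module(t) for t in unrelated)


related = ["I ate pizza and soup", "fresh bread"]  # text written for the explanation "food"
unrelated = ["the car is red", "a long meeting"]
print(explanation_score(food_module, related, unrelated))  # 1.5
```

Only the module's outputs are used, which matches the black-box setting: the score never inspects the module's internals.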

LG · Jan 28, 2022
Fast Interpretable Greedy-Tree Sums

Yan Shuo Tan, Chandan Singh, Keyan Nasseri et al.

Modern machine learning has achieved impressive prediction performance, but often sacrifices interpretability, a critical consideration in high-stakes domains such as medicine. In such settings, practitioners often use highly interpretable decision tree models, but these suffer from inductive bias against additive structure. To overcome this bias, we propose Fast Interpretable Greedy-Tree Sums (FIGS), which generalizes the CART algorithm to simultaneously grow a flexible number of trees in summation. By combining logical rules with addition, FIGS is able to adapt to additive structure while remaining highly interpretable. Extensive experiments on real-world datasets show that FIGS achieves state-of-the-art prediction performance. To demonstrate the usefulness of FIGS in high-stakes domains, we adapt FIGS to learn clinical decision instruments (CDIs), which are tools for guiding clinical decision-making. Specifically, we introduce a variant of FIGS known as G-FIGS that accounts for the heterogeneity in medical data. G-FIGS derives CDIs that reflect domain knowledge and enjoy improved specificity (by up to 20% over CART) without sacrificing sensitivity or interpretability. To provide further insight into FIGS, we prove that FIGS learns components of additive models, a property we refer to as disentanglement. Further, we show (under oracle conditions) that unconstrained tree-sum models leverage disentanglement to generalize more efficiently than single decision tree models when fitted to additive regression functions. Finally, to avoid overfitting with an unconstrained number of splits, we develop Bagging-FIGS, an ensemble version of FIGS that borrows the variance reduction techniques of random forests. Bagging-FIGS enjoys competitive performance with random forests and XGBoost on real-world datasets.
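
The advantage of summing trees on additive structure can be checked on a toy function $y = \mathbb{1}[x_0 > 0] + 2\cdot\mathbb{1}[x_1 > 0]$. Two one-split trees represent it exactly, while a single tree needs three splits because the $x_1$ split must be duplicated under both branches of the $x_0$ split. The tuple encoding below is a hypothetical illustration of a fitted tree-sum model's prediction step; FIGS's fitting procedure is not shown.

```python
from typing import Sequence, Tuple, Union

# A tree is either a leaf value or (feature, threshold, left_subtree, right_subtree).
Tree = Union[float, Tuple]


def predict_tree(tree: Tree, x: Sequence[float]) -> float:
    """Route x down the tree until a leaf value is reached."""
    while isinstance(tree, tuple):
        feature, threshold, left, right = tree
        tree = left if x[feature] <= threshold else right
    return tree


def predict_tree_sum(trees: Sequence[Tree], x: Sequence[float]) -> float:
    """A tree-sum model predicts by adding the outputs of its trees."""
    return sum(predict_tree(t, x) for t in trees)


# Two one-split trees: one per additive component (2 splits in total).
tree_sum = [(0, 0.0, 0.0, 1.0), (1, 0.0, 0.0, 2.0)]
# One tree for the same function repeats the x1 split under both branches (3 splits).
single_tree = (0, 0.0, (1, 0.0, 0.0, 2.0), (1, 0.0, 1.0, 3.0))

for x in ([-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]):
    print(x, predict_tree_sum(tree_sum, x), predict_tree(single_tree, x))
```

With more additive components the gap grows: a single tree must replicate each component's splits across every branch of the others.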

ML · Aug 16, 2021
Interpreting and improving deep-learning models with reality checks

Chandan Singh, Wooseok Ha, Bin Yu

Recent deep-learning models have achieved impressive predictive performance by learning complex functions of many variables, often at the cost of interpretability. This chapter covers recent work aiming to interpret models by attributing importance to features and feature groups for a single prediction. Importantly, the proposed attributions assign importance to interactions between features, in addition to features in isolation. These attributions are shown to yield insights across real-world domains, including bio-imaging, cosmology imaging, and natural-language processing. We then show how these attributions can be used to directly improve the generalization of a neural network or to distill it into a simple model. Throughout the chapter, we emphasize the use of reality checks to scrutinize the proposed interpretation techniques.

LG · Jun 17, 2020
Revisiting minimum description length complexity in overparameterized models

Raaz Dwivedi, Chandan Singh, Bin Yu et al.

Complexity is a fundamental concept underlying statistical learning theory that aims to inform generalization performance. Parameter count, while successful in low-dimensional settings, is not well-justified for overparameterized settings, where the number of parameters exceeds the number of training samples. We revisit complexity measures based on Rissanen's principle of minimum description length (MDL) and define a novel MDL-based complexity (MDL-COMP) that remains valid for overparameterized models. MDL-COMP is defined via an optimality criterion over the encodings induced by a good Ridge estimator class. We provide an extensive theoretical characterization of MDL-COMP for linear models and kernel methods and show that it is not just a function of parameter count, but rather a function of the singular values of the design or the kernel matrix and the signal-to-noise ratio. For a linear model with $n$ observations, $d$ parameters, and i.i.d. Gaussian predictors, MDL-COMP scales linearly with $d$ when $d<n$, but the scaling is exponentially smaller, $\log d$, for $d>n$. For kernel methods, we show that MDL-COMP informs minimax in-sample error, and can decrease as the dimensionality of the input increases. We also prove that MDL-COMP upper bounds the in-sample mean squared error (MSE). Via an array of simulations and real-data experiments, we show that a data-driven Prac-MDL-COMP informs hyper-parameter tuning for optimizing test MSE with ridge regression in limited data settings, sometimes improving upon cross-validation and (always) saving computational costs. Finally, our findings also suggest that the recently observed double descent phenomenon in overparameterized models might be a consequence of the choice of non-ideal estimators.

AP · May 16, 2020
Curating a COVID-19 data repository and forecasting county-level death counts in the United States

Nick Altieri, Rebecca L. Barter, James Duncan et al.

As the COVID-19 outbreak evolves, accurate forecasting continues to play an extremely important role in informing policy decisions. In this paper, we present our continuous curation of a large data repository containing COVID-19 information from a range of sources. We use this data to develop predictions and corresponding prediction intervals for the short-term trajectory of COVID-19 cumulative death counts at the county-level in the United States up to two weeks ahead. Using data from January 22 to June 20, 2020, we develop and combine multiple forecasts using ensembling techniques, resulting in an ensemble we refer to as Combined Linear and Exponential Predictors (CLEP). Our individual predictors include county-specific exponential and linear predictors, a shared exponential predictor that pools data together across counties, an expanded shared exponential predictor that uses data from neighboring counties, and a demographics-based shared exponential predictor. We use prediction errors from the past five days to assess the uncertainty of our death predictions, resulting in generally applicable prediction intervals, Maximum (absolute) Error Prediction Intervals (MEPI). MEPI achieves a coverage rate of more than 94% when averaged across counties for predicting cumulative recorded death counts two weeks in the future. Our forecasts are currently being used by the non-profit organization Response4Life to determine the medical supply need for individual hospitals and have directly contributed to the distribution of medical supplies across the country. We hope that our forecasts and data repository at https://covidseverity.com can help guide necessary county-specific decision-making and help counties prepare for their continued fight against COVID-19.
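
The MEPI construction can be sketched as follows, assuming (as a labeled simplification) that each past error is normalized by the recorded count: take the largest normalized absolute error over the last five days and scale the new prediction down and up by that factor. The function name and the toy numbers are hypothetical, not data from the paper.

```python
from typing import Sequence, Tuple


def mepi_interval(
    prediction: float,
    past_actual: Sequence[float],
    past_predicted: Sequence[float],
    window: int = 5,
) -> Tuple[float, float]:
    """Interval prediction * (1 -/+ delta), delta = max recent normalized absolute error."""
    recent = list(zip(past_actual, past_predicted))[-window:]
    # Assumption: errors are normalized by the recorded count; zero-count days are skipped.
    delta = max((abs(a - p) / a for a, p in recent if a > 0), default=0.0)
    return prediction * (1 - delta), prediction * (1 + delta)


actual = [100, 110, 120, 130, 140]
predicted = [98, 112, 120, 125, 147]  # largest normalized error: 7 / 140 = 0.05
print(mepi_interval(200.0, actual, predicted))  # approximately (190.0, 210.0)
```

Using the maximum rather than an average of recent errors trades interval width for coverage, which fits the high coverage rate the abstract reports.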

ML · Mar 4, 2020
Transformation Importance with Applications to Cosmology

Chandan Singh, Wooseok Ha, Francois Lanusse et al.

Machine learning lies at the heart of new possibilities for scientific discovery, knowledge generation, and artificial intelligence. Realizing its potential benefits in these fields requires going beyond predictive accuracy and focusing on interpretability. In particular, many scientific problems require interpretations in a domain-specific interpretable feature space (e.g. the frequency domain), whereas attributions to the raw features (e.g. the pixel space) may be unintelligible or even misleading. To address this challenge, we propose TRIM (TRansformation IMportance), a novel approach which attributes importance to features in a transformed space and can be applied post-hoc to a fully trained model. TRIM is motivated by a cosmological parameter estimation problem using deep neural networks (DNNs) on simulated data, but it is generally applicable across domains/models and can be combined with any local interpretation method. In our cosmology example, combining TRIM with contextual decomposition shows promising results for identifying which frequencies a DNN uses, helping cosmologists to understand and validate that the model learns appropriate physical features rather than simulation artifacts.
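
Attributing in a transformed space can be sketched by occlusion in the Fourier domain: transform the input, zero one frequency (and its mirror bin, so the signal stays real), map back, and measure how much the model output changes. This swaps simple occlusion in for the contextual decomposition used in the paper, uses a naive $O(n^2)$ DFT so the example stays self-contained, and the function names are hypothetical.

```python
import cmath
from typing import Callable, List, Sequence


def dft(x: Sequence[float]) -> List[complex]:
    """Naive discrete Fourier transform (O(n^2), fine for a toy example)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n)) for k in range(n)]


def idft(X: Sequence[complex]) -> List[float]:
    """Inverse DFT, keeping the real part (inputs here are real signals)."""
    n = len(X)
    return [(sum(X[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n).real
            for t in range(n)]


def frequency_importance(model: Callable[[Sequence[float]], float],
                         x: Sequence[float], k: int) -> float:
    """Change in model output when frequency k (and its mirror bin) is removed."""
    n = len(x)
    spectrum = dft(x)
    spectrum[k] = 0
    spectrum[(n - k) % n] = 0  # zero the conjugate bin too, so the signal stays real
    return model(x) - model(idft(spectrum))


signal = [1.0, 2.0, 3.0, 4.0]
toy_model = sum  # toy "model": its output depends only on the mean level (frequency 0)
print(round(frequency_importance(toy_model, signal, 0), 6))  # 10.0
print(frequency_importance(toy_model, signal, 1))  # approximately 0
```

The model itself is untouched; only the input is re-expressed, which is why this style of attribution can be applied post hoc to a trained network.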

LG · Sep 30, 2019
Interpretations are useful: penalizing explanations to align neural networks with prior knowledge

Laura Rieger, Chandan Singh, W. James Murdoch et al.

For an explanation of a deep learning model to be effective, it must both provide insight into the model and suggest a corresponding action in order to achieve some objective. Too often, proposed explainable deep learning methods stop at the first step, providing practitioners with insight into a model but no way to act on it. In this paper, we propose contextual decomposition explanation penalization (CDEP), a method which enables practitioners to leverage existing explanation methods in order to increase the predictive accuracy of deep learning models. In particular, when shown that a model has incorrectly assigned importance to some features, CDEP enables practitioners to correct these errors by directly regularizing the provided explanations. Using explanations provided by contextual decomposition (CD) (Murdoch et al., 2018), we demonstrate the ability of our method to increase performance on an array of toy and real datasets.
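
The shape of an explanation-penalized objective can be sketched for a logistic model, where each feature's contribution to the logit is exactly $w_j x_j$. The total loss is the usual cross-entropy plus $\lambda$ times the magnitude of the contributions assigned to features the practitioner marks as irrelevant. CDEP computes these contributions with contextual decomposition for deep networks; the linear attribution, the function name, and the value of `lam` here are assumptions for illustration.

```python
import math
from typing import Iterable, Sequence


def penalized_loss(
    w: Sequence[float],
    b: float,
    x: Sequence[float],
    y: int,
    penalized_features: Iterable[int],
    lam: float,
) -> float:
    """Cross-entropy plus lam * |contribution| of the features marked as irrelevant."""
    logit = sum(wj * xj for wj, xj in zip(w, x)) + b
    p = 1.0 / (1.0 + math.exp(-logit))
    cross_entropy = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    # For a linear logit, feature j contributes exactly w_j * x_j.
    explanation_penalty = sum(abs(w[j] * x[j]) for j in penalized_features)
    return cross_entropy + lam * explanation_penalty


# Feature 0 is a known spurious cue: relying on it now costs lam * |w_0 * x_0|.
print(penalized_loss([2.0, 0.0], 0.0, [1.0, 1.0], 1, {0}, lam=1.0))
```

Minimizing this objective with any gradient-based optimizer pushes weight away from the penalized features even when they are predictive on the training data.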

ML · May 18, 2019
Disentangled Attribution Curves for Interpreting Random Forests and Boosted Trees

Summer Devlin, Chandan Singh, W. James Murdoch et al.

Tree ensembles, such as random forests and AdaBoost, are ubiquitous machine learning models known for achieving strong predictive performance across a wide variety of domains. However, this strong performance comes at the cost of interpretability (i.e. users are unable to understand the relationships a trained random forest has learned and why it is making its predictions). In particular, it is challenging to understand how the contribution of a particular feature, or group of features, varies as their value changes. To address this, we introduce Disentangled Attribution Curves (DAC), a method to provide interpretations of tree ensemble methods in the form of (multivariate) feature importance curves. For a given variable, or group of variables, DAC plots the importance of the variable(s) as their values change. We validate DAC on real data by showing that the curves can be used to increase the accuracy of logistic regression while maintaining interpretability, by including DAC as an additional feature. In simulation studies, DAC is shown to outperform competing methods in the recovery of conditional expectations. Finally, through a case study on the bike-sharing dataset, we demonstrate the use of DAC to uncover novel insights into a dataset.
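
For the special case of an ensemble of depth-one trees (stumps), a feature's attribution curve can be read off directly: at each value on a grid, sum the outputs of the stumps that split on that feature. This is a hypothetical simplification; DAC itself handles deeper trees and feature interactions, which this sketch does not.

```python
from typing import List, Sequence, Tuple

# (feature, threshold, value if x[feature] <= threshold, value otherwise)
Stump = Tuple[int, float, float, float]


def attribution_curve(stumps: Sequence[Stump], feature: int, grid: Sequence[float]) -> List[float]:
    """Sum the outputs of all stumps that split on `feature`, at each grid value."""
    return [
        sum((left if v <= thr else right) for f, thr, left, right in stumps if f == feature)
        for v in grid
    ]


ensemble = [(0, 1.0, -1.0, 1.0), (0, 3.0, 0.0, 2.0), (1, 0.0, 5.0, -5.0)]
print(attribution_curve(ensemble, feature=0, grid=[0.0, 2.0, 4.0]))  # [-1.0, 1.0, 3.0]
```

Because a stump ensemble is additive, this curve is exactly the feature's contribution to the prediction; for deeper trees the contributions of interacting features must be disentangled, which is the harder problem DAC addresses.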

ML · Jan 14, 2019
Interpretable machine learning: definitions, methods, and applications

W. James Murdoch, Chandan Singh, Karl Kumbier et al.

Machine-learning models have demonstrated great success in learning complex patterns that enable them to make predictions about unobserved data. In addition to using models for prediction, the ability to interpret what a model has learned is receiving an increasing amount of attention. However, this increased focus has led to considerable confusion about the notion of interpretability. In particular, it is unclear how the wide array of proposed interpretation methods are related, and what common concepts can be used to evaluate them. We aim to address these concerns by defining interpretability in the context of machine learning and introducing the Predictive, Descriptive, Relevant (PDR) framework for discussing interpretations. The PDR framework provides three overarching desiderata for evaluation: predictive accuracy, descriptive accuracy and relevancy, with relevancy judged relative to a human audience. Moreover, to help manage the deluge of interpretation methods, we introduce a categorization of existing techniques into model-based and post-hoc categories, with sub-groups including sparsity, modularity and simulatability. To demonstrate how practitioners can use the PDR framework to evaluate and understand interpretations, we provide numerous real-world examples. These examples highlight the often under-appreciated role played by human audiences in discussions of interpretability. Finally, based on our framework, we discuss limitations of existing methods and directions for future work. We hope that this work will provide a common vocabulary that will make it easier for both practitioners and researchers to discuss and choose from the full range of interpretation methods.

LG · Jun 14, 2018
Hierarchical interpretations for neural network predictions

Chandan Singh, W. James Murdoch, Bin Yu

Deep neural networks (DNNs) have achieved impressive predictive performance due to their ability to learn complex, non-linear relationships between variables. However, the inability to effectively visualize these relationships has led to DNNs being characterized as black boxes and consequently limited their applications. To ameliorate this problem, we introduce the use of hierarchical interpretations to explain DNN predictions through our proposed method, agglomerative contextual decomposition (ACD). Given a prediction from a trained DNN, ACD produces a hierarchical clustering of the input features, along with the contribution of each cluster to the final prediction. This hierarchy is optimized to identify clusters of features that the DNN learned are predictive. Using examples from Stanford Sentiment Treebank and ImageNet, we show that ACD is effective at diagnosing incorrect predictions and identifying dataset bias. Through human experiments, we demonstrate that ACD enables users both to identify the more accurate of two DNNs and to better trust a DNN's outputs. We also find that ACD's hierarchy is largely robust to adversarial perturbations, implying that it captures fundamental aspects of the input and ignores spurious noise.

NC · Sep 13, 2017
A Constrained, Weighted-L1 Minimization Approach for Joint Discovery of Heterogeneous Neural Connectivity Graphs

Chandan Singh, Beilun Wang, Yanjun Qi

Determining functional brain connectivity is crucial to understanding the brain and neural differences underlying disorders such as autism. Recent studies have used Gaussian graphical models to learn brain connectivity via statistical dependencies across brain regions from neuroimaging. However, previous studies often fail to properly incorporate priors tailored to neuroscience, such as preferring shorter connections. To remedy this problem, this paper introduces a novel weighted-$\ell_1$, multi-task graphical model (W-SIMULE). This model elegantly incorporates a flexible prior, along with a parallelizable formulation. Additionally, W-SIMULE extends the often-used Gaussian assumption, leading to considerable performance increases. Here, applications to fMRI data show that W-SIMULE succeeds in determining functional connectivity in terms of (1) log-likelihood, (2) finding edges that differentiate groups, and (3) classifying different groups based on their connectivity, achieving 58.6% accuracy on the ABIDE dataset. Having established its effectiveness, we use W-SIMULE to link four key areas to autism, all of which are consistent with the literature. Due to its elegant domain adaptivity, W-SIMULE can be readily applied to various data types to effectively estimate connectivity.
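
The distance-based prior can be illustrated through the penalty term alone: weighting each off-diagonal entry of a precision matrix by the physical distance between the two regions makes long-range edges more expensive than short ones. Using Euclidean distance between region centroids as the weight, the toy coordinates, and the function name are assumptions for illustration; the W-SIMULE estimator itself (a constrained, multi-task optimization) is not shown.

```python
import math
from typing import Sequence


def weighted_l1_penalty(
    precision: Sequence[Sequence[float]], coords: Sequence[Sequence[float]]
) -> float:
    """Sum over off-diagonal entries of distance(i, j) * |Omega_ij|."""
    d = len(precision)
    total = 0.0
    for i in range(d):
        for j in range(d):
            if i != j:
                total += math.dist(coords[i], coords[j]) * abs(precision[i][j])
    return total


coords = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]  # toy region centroids
short_edge = [[1.0, 0.5, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 1.0]]
long_edge = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.5, 0.0, 1.0]]
print(weighted_l1_penalty(short_edge, coords), weighted_l1_penalty(long_edge, coords))  # 1.0 5.0
```

Both matrices have the same unweighted off-diagonal $\ell_1$ norm, so a plain $\ell_1$ penalty could not prefer one over the other; the weighting is what encodes the preference for shorter connections.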

CV · Sep 9, 2017
Large Scale Image Segmentation with Structured Loss based Deep Learning for Connectome Reconstruction

Jan Funke, Fabian David Tschopp, William Grisaitis et al.

We present a method combining affinity prediction with region agglomeration, which improves significantly upon the state of the art of neuron segmentation from electron microscopy (EM) in accuracy and scalability. Our method consists of a 3D U-Net, trained to predict affinities between voxels, followed by iterative region agglomeration. We train using a structured loss based on MALIS, encouraging topologically correct segmentations obtained from affinity thresholding. Our extension consists of two parts: First, we present a quasi-linear method to compute the loss gradient, improving over the original quadratic algorithm. Second, we compute the gradient in two separate passes to avoid spurious gradient contributions in early training stages. Our predictions are accurate enough that simple learning-free percentile-based agglomeration outperforms more involved methods used earlier on inferior predictions. We present results on three diverse EM datasets, achieving relative improvements over previous results of 27%, 15%, and 250%. Our findings suggest that a single method can be applied to both nearly isotropic block-face EM data and anisotropic serial sectioned EM data. The runtime of our method scales linearly with the size of the volume and achieves a throughput of about 2.6 seconds per megavoxel, qualifying our method for the processing of very large datasets.
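
The MALIS loss is built around segmentations obtained by thresholding affinities. That thresholding step can be sketched as connected components over a voxel graph with one affinity per neighbouring pair, using a union-find. The edge-list format and the function name are hypothetical, and neither the MALIS training nor the percentile-based agglomeration is shown.

```python
from typing import List, Sequence, Tuple


def segment_by_affinity(
    n_voxels: int, edges: Sequence[Tuple[int, int, float]], threshold: float
) -> List[int]:
    """Label voxels by connected components of edges whose affinity exceeds threshold."""
    parent = list(range(n_voxels))

    def find(a: int) -> int:
        # Path halving keeps the union-find trees shallow.
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    for u, v, affinity in edges:
        if affinity > threshold:
            ru, rv = find(u), find(v)
            if ru != rv:
                parent[rv] = ru  # merge the two segments
    # Relabel component roots as consecutive segment ids 0, 1, 2, ...
    relabel: dict = {}
    return [relabel.setdefault(find(i), len(relabel)) for i in range(n_voxels)]


# Four voxels in a row; a low affinity between voxels 1 and 2 marks a boundary.
edges = [(0, 1, 0.9), (1, 2, 0.1), (2, 3, 0.8)]
print(segment_by_affinity(4, edges, threshold=0.5))  # [0, 0, 1, 1]
```

A single wrongly high affinity merges two neurons under this rule, which is why a loss that targets topological correctness of the thresholded result, rather than per-edge accuracy, is useful.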