Abhineet Agarwal

LG
h-index: 32
11 papers
127 citations
Novelty: 54%
AI Score: 40

11 Papers

ME · Mar 24, 2023
Synthetic Combinations: A Causal Inference Framework for Combinatorial Interventions

Abhineet Agarwal, Anish Agarwal, Suhas Vijaykumar · berkeley

Consider a setting with $N$ heterogeneous units and $p$ interventions. Our goal is to learn unit-specific potential outcomes for any combination of these $p$ interventions, i.e., $N \times 2^p$ causal parameters. Choosing a combination of interventions is a problem that naturally arises in a variety of applications, such as factorial design experiments, recommendation engines, combination therapies in medicine, and conjoint analysis. Running $N \times 2^p$ experiments to estimate these parameters is likely expensive and/or infeasible as $N$ and $p$ grow. Further, with observational data there is likely confounding, i.e., whether or not a unit is seen under a combination is correlated with its potential outcome under that combination. To address these challenges, we propose a novel latent factor model that imposes structure across units (i.e., the matrix of potential outcomes is approximately rank $r$) and across combinations of interventions (i.e., the coefficients in the Fourier expansion of the potential outcomes are approximately $s$-sparse). We establish identification for all $N \times 2^p$ parameters despite unobserved confounding. We propose an estimation procedure, Synthetic Combinations, and establish that it is finite-sample consistent and asymptotically normal under precise conditions on the observation pattern. Our results imply consistent estimation given $\text{poly}(r) \times \left(N + s^2 p\right)$ observations, while previous methods have sample complexity scaling as $\min\left(N \times s^2 p,\ \text{poly}(r) \times (N + 2^p)\right)$. We use Synthetic Combinations to propose a data-efficient experimental design. Empirically, Synthetic Combinations outperforms competing approaches on a real-world movie recommendation dataset. Lastly, we extend our analysis to causal inference where the intervention is a permutation over $p$ items (e.g., rankings).
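A minimal sketch of the sparsity assumption, under conventions of our own choosing: a combination is encoded as a $p$-bit integer (bit $j$ set iff intervention $j$ is applied), and the Fourier basis is taken to be the standard Boolean parity basis $\chi_S(\pi) = (-1)^{|S \cap \pi|}$, so a unit's potential outcomes satisfy $y(\pi) = \sum_S \alpha_S \chi_S(\pi)$. The abstract does not state the paper's exact encoding, and the helper names (`chi`, `outcomes_from_coeffs`, `walsh_hadamard`) are hypothetical. The fast Walsh-Hadamard transform recovers all $2^p$ coefficients of one unit in $O(p\,2^p)$ time; with $s = 3$ nonzero coefficients out of $2^p = 8$, only three survive.

```python
def chi(S, pi):
    """Parity basis function chi_S(pi) = (-1)^{|S ∩ pi|}, with S and pi as bitmasks."""
    return -1 if bin(S & pi).count("1") % 2 else 1


def outcomes_from_coeffs(coeffs, p):
    """Potential outcomes y(pi) = sum_S alpha_S * chi_S(pi) for all 2^p combinations."""
    return [sum(a * chi(S, pi) for S, a in coeffs.items()) for pi in range(2 ** p)]


def walsh_hadamard(values):
    """Fourier coefficients alpha_S = 2^-p * sum_pi y(pi) * chi_S(pi).

    values[pi] is the outcome under combination pi; len(values) must be 2^p.
    """
    a = [float(v) for v in values]
    n = len(a)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    h = 1
    while h < n:
        # butterfly step: combine entries whose indices differ only in the bit of weight h
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a[j], a[j + h] = a[j] + a[j + h], a[j] - a[j + h]
        h *= 2
    return [c / n for c in a]


p = 3
# s = 3 nonzero coefficients: baseline, intervention 0 alone, and the pair {0, 1}
true_coeffs = {0b000: 2.0, 0b001: 1.5, 0b011: -0.5}
y = outcomes_from_coeffs(true_coeffs, p)
alpha = walsh_hadamard(y)
support = {S: a for S, a in enumerate(alpha) if abs(a) > 1e-12}
print(support)  # {0: 2.0, 1: 1.5, 3: -0.5}
```

This transform needs all $2^p$ outcomes of the unit, which is exactly what is unavailable in practice; the point of combining sparsity with the low-rank structure across units is that far fewer observations suffice.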

ME · Jul 4, 2023
Integrating Random Forests and Generalized Linear Models for Improved Accuracy and Interpretability

Abhineet Agarwal, Ana M. Kenney, Yan Shuo Tan et al. · berkeley

Random forests (RFs) are among the most popular supervised learning algorithms due to their nonlinear flexibility and ease of use. However, as black-box models, they can only be interpreted via algorithmically defined feature importance methods, such as Mean Decrease in Impurity (MDI), which have been observed to be highly unstable and to have ambiguous scientific meaning. Furthermore, they can perform poorly in the presence of smooth or additive structure. To address this, we reinterpret decision trees and MDI as linear regression and $R^2$ values, respectively, with respect to engineered features associated with the tree's decision splits. This allows us to combine the respective strengths of RFs and generalized linear models in a framework called RF+, which also yields an improved feature importance method we call MDI+. Through extensive data-inspired simulations and real-world datasets, we show that RF+ improves prediction accuracy over RFs and that MDI+ outperforms popular feature importance measures in identifying signal features, often yielding more than a 10% improvement over its closest competitor. In case studies on drug response prediction and breast cancer subtyping, we further show that MDI+ extracts well-established genes with significantly greater stability than existing feature importance measures.
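The reading of MDI as an $R^2$ value can be checked in its simplest case, a single split. In this sketch (not the RF+ implementation, which uses richer engineered features and generalized linear models), the engineered feature is the indicator of falling in the left child, and MDI is the split's decrease in squared-error impurity normalized by the root's impurity. The two quantities coincide because the least-squares fit on that indicator predicts each child's mean. The data, threshold, and function names are made up for illustration; the code assumes both children are non-empty and $y$ is not constant.

```python
def mean(v):
    return sum(v) / len(v)


def sse(v):
    """Sum of squared deviations from the mean (n times the variance impurity)."""
    m = mean(v)
    return sum((u - m) ** 2 for u in v)


def stump_mdi(x, y, threshold):
    """Impurity decrease of one split, normalized by the root impurity."""
    left = [yi for xi, yi in zip(x, y) if xi <= threshold]
    right = [yi for xi, yi in zip(x, y) if xi > threshold]
    return (sse(y) - sse(left) - sse(right)) / sse(y)


def r2_on_split_indicator(x, y, threshold):
    """R^2 of an ordinary least-squares fit y ~ b0 + b1 * 1[x <= threshold]."""
    z = [1.0 if xi <= threshold else 0.0 for xi in x]  # engineered split feature
    zbar, ybar = mean(z), mean(y)
    b1 = sum((zi - zbar) * (yi - ybar) for zi, yi in zip(z, y)) / sum(
        (zi - zbar) ** 2 for zi in z
    )
    b0 = ybar - b1 * zbar
    rss = sum((yi - b0 - b1 * zi) ** 2 for zi, yi in zip(z, y))
    return 1.0 - rss / sse(y)


x = [1, 2, 3, 4, 5, 6]
y = [1.0, 1.2, 0.8, 3.1, 2.9, 3.0]
print(stump_mdi(x, y, 3.5), r2_on_split_indicator(x, y, 3.5))  # equal up to rounding
```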

CL · Feb 21, 2024 · Code
ED-Copilot: Reduce Emergency Department Wait Time with Language Model Diagnostic Assistance

Liwen Sun, Abhineet Agarwal, Aaron Kornblith et al. · berkeley

In the emergency department (ED), patients undergo triage and multiple laboratory tests before diagnosis. This time-consuming process causes ED crowding, which affects patient mortality, medical errors, and staff burnout, among other things. This work proposes time-cost-effective diagnostic assistance that leverages artificial intelligence systems to help ED clinicians make efficient and accurate diagnoses. In collaboration with ED clinicians, we use public patient data to curate MIMIC-ED-Assist, a benchmark for AI systems to suggest laboratory tests that minimize wait time while accurately predicting critical outcomes such as death. With MIMIC-ED-Assist, we develop ED-Copilot, which sequentially suggests patient-specific laboratory tests and makes diagnostic predictions. ED-Copilot employs a pre-trained biomedical language model to encode patient information and uses reinforcement learning to minimize ED wait time and maximize prediction accuracy. On MIMIC-ED-Assist, ED-Copilot improves prediction accuracy over baselines while halving average wait time from four hours to two hours. ED-Copilot can also effectively personalize treatment recommendations based on patient severity, further highlighting its potential as a diagnostic assistant. Since MIMIC-ED-Assist is a retrospective benchmark, ED-Copilot is restricted to recommending only observed tests. We show that ED-Copilot achieves competitive performance without this restriction as the maximum allowed time increases. Our code is available at https://github.com/cxcscmu/ED-Copilot.

LG · Feb 2, 2022 · Code
Hierarchical Shrinkage: improving the accuracy and interpretability of tree-based methods

Abhineet Agarwal, Yan Shuo Tan, Omer Ronen et al.

Tree-based models such as decision trees and random forests (RF) are a cornerstone of modern machine-learning practice. To mitigate overfitting, trees are typically regularized by a variety of techniques that modify their structure (e.g., pruning). We introduce Hierarchical Shrinkage (HS), a post-hoc algorithm that does not modify the tree structure and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors. The amount of shrinkage is controlled by a single regularization parameter and the number of data points in each ancestor. Since HS is a post-hoc method, it is extremely fast, compatible with any tree-growing algorithm, and can be used synergistically with other regularization techniques. Extensive experiments over a wide variety of real-world datasets show that HS substantially increases the predictive performance of decision trees, even when used in conjunction with other regularization techniques. Moreover, we find that applying HS to each tree in an RF often improves accuracy, as well as its interpretability by simplifying and stabilizing its decision boundaries and SHAP values. We further explain the success of HS in improving prediction performance by showing its equivalence to ridge regression on a (supervised) basis constructed from decision stumps associated with the internal nodes of a tree. All code and models are released in a full-fledged package available on GitHub (github.com/csinva/imodels).
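The shrinkage rule can be sketched for a single prediction, assuming the form we read from the HS paper: for a sample whose root-to-leaf path is $t_0, \dots, t_L$, with node means $\mathbb{E}_{t_l}[y]$ and node sample counts $N(t_l)$,
$\hat f_\lambda(x) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L} \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \lambda / N(t_{l-1})}$.
The function name and inputs below are hypothetical; imodels ships the authors' implementation, whose API is not reproduced here.

```python
def hs_predict(path_means, path_counts, lam):
    """Hierarchically shrunk prediction for one sample.

    path_means[l]: mean response of the l-th node on the root-to-leaf path (0 = root)
    path_counts[l]: number of training samples in that node
    lam: the single regularization parameter (lam = 0 means no shrinkage)
    """
    pred = path_means[0]
    for l in range(1, len(path_means)):
        # each increment is divided by a factor that depends on the parent's sample count
        pred += (path_means[l] - path_means[l - 1]) / (1.0 + lam / path_counts[l - 1])
    return pred


means = [10.0, 14.0, 20.0]  # root, internal node, leaf
counts = [100, 20, 2]
print(hs_predict(means, counts, 0.0))             # 20.0, the unshrunk leaf mean
print(round(hs_predict(means, counts, 10.0), 2))  # 17.64
```

With $\lambda = 0$ the sum telescopes back to the leaf mean, and as $\lambda$ grows the prediction collapses toward the root mean. Increments below small nodes are shrunk the most, since their denominators $1 + \lambda / N(t_{l-1})$ are largest.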

ML · May 13, 2025
PCS-UQ: Uncertainty Quantification via the Predictability-Computability-Stability Framework

Abhineet Agarwal, Michael Xiao, Rebecca Barter et al. · berkeley

As machine learning (ML) models are increasingly deployed in high-stakes domains, trustworthy uncertainty quantification (UQ) is critical for ensuring the safety and reliability of these models. Traditional UQ methods rely on specifying a true generative model and are not robust to misspecification. On the other hand, conformal inference allows for arbitrary ML models but does not consider model selection, which leads to large interval sizes. We tackle these drawbacks by proposing a UQ method based on the predictability, computability, and stability (PCS) framework for veridical data science proposed by Yu and Kumbier. Specifically, PCS-UQ addresses model selection by using a prediction check to screen out unsuitable models. PCS-UQ then fits these screened algorithms across multiple bootstraps to assess inter-sample variability and algorithmic instability, enabling more reliable uncertainty estimates. Further, we propose a novel calibration scheme that improves the local adaptivity of our prediction sets. Experiments across $17$ regression and $6$ classification datasets show that PCS-UQ achieves the desired coverage and reduces width relative to conformal approaches by $\approx 20\%$. Further, our local analysis shows that PCS-UQ often achieves target coverage across subgroups, while conformal methods fail to do so. For large deep-learning models, we propose computationally efficient approximation schemes that avoid the expensive multiple bootstrap trainings of PCS-UQ. Across three computer vision benchmarks, PCS-UQ reduces prediction set size relative to conformal methods by $20\%$. Theoretically, we show that a modified PCS-UQ algorithm is a form of split conformal inference and achieves the desired coverage with exchangeable data.
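One way to picture the bootstrap and calibration steps, under assumptions of ours that the abstract does not pin down: each calibration point comes with predictions from bootstrap fits of models that already passed the prediction check; a raw interval is taken from the spread of those predictions; and a single multiplicative factor $\gamma$ is chosen on a calibration set so that coverage reaches $1 - \alpha$. Stretching each interval around its own median keeps it wider where the bootstrap fits disagree more, which is one simple form of local adaptivity. This is a hypothetical sketch, not the paper's calibration scheme, and all names and data are ours.

```python
import math
import statistics


def quantile(sorted_vals, q):
    """Empirical quantile with linear interpolation between order statistics."""
    pos = q * (len(sorted_vals) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def raw_interval(boot_preds, alpha):
    """Uncalibrated interval from the spread of bootstrap predictions."""
    s = sorted(boot_preds)
    return quantile(s, alpha / 2), quantile(s, 1 - alpha / 2)


def scaled_interval(boot_preds, alpha, gamma):
    """Stretch the raw interval by gamma around the median bootstrap prediction."""
    center = statistics.median(boot_preds)
    lo, hi = raw_interval(boot_preds, alpha)
    return center - gamma * (center - lo), center + gamma * (hi - center)


def calibrate_gamma(boot_preds_list, y_cal, alpha, grid):
    """Smallest gamma in grid whose calibration-set coverage is at least 1 - alpha."""
    for gamma in sorted(grid):
        covered = 0
        for preds, y in zip(boot_preds_list, y_cal):
            lo, hi = scaled_interval(preds, alpha, gamma)
            covered += lo <= y <= hi
        if covered / len(y_cal) >= 1 - alpha:
            return gamma
    return None  # no gamma in the grid reaches the target coverage


offsets = [-1.0, -0.5, 0.0, 0.5, 1.0]
spreads = [1.0, 1.0, 1.0, 2.0, 1.0]  # the fourth point's bootstrap fits disagree more
boot = [[w * d for d in offsets] for w in spreads]
y_cal = [0.1, -0.5, 0.9, 1.1, -1.7]
print(calibrate_gamma(boot, y_cal, alpha=0.25, grid=[1.0, 1.25, 1.5, 2.0]))  # 1.25
```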

LG · Feb 19, 2025
SPEX: Scaling Feature Interaction Explanations for LLMs

Justin Singh Kang, Landon Butler, Abhineet Agarwal et al. · berkeley

Large language models (LLMs) have revolutionized machine learning due to their ability to capture complex interactions between input features. Popular post-hoc explanation methods like SHAP provide marginal feature attributions, while their extensions to interaction importances only scale to small input lengths ($\approx 20$). We propose Spectral Explainer (SPEX), a model-agnostic interaction attribution algorithm that efficiently scales to large input lengths ($\approx 1000$). SPEX exploits the natural sparsity among interactions, which is common in real-world data, and applies a sparse Fourier transform using a channel decoding algorithm to efficiently identify important interactions. We perform experiments across three difficult long-context datasets that require LLMs to utilize interactions between inputs to complete the task. For large inputs, SPEX outperforms marginal attribution methods by up to 20% in terms of faithfully reconstructing LLM outputs. Further, SPEX successfully identifies key features and interactions that strongly influence model output. For one of our datasets, HotpotQA, SPEX provides interactions that align with human annotations. Finally, we use our model-agnostic approach to generate explanations that demonstrate abstract reasoning in closed-source LLMs (GPT-4o mini) and compositional reasoning in vision-language models.
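The quantity SPEX targets can be illustrated on a toy function of a binary mask $m \in \{0,1\}^n$ (bit $i$ set means feature $i$ is kept). Writing $f(m) = \sum_S F(S)\,(-1)^{|S \cap m|}$, a model whose output depends on only a few interactions has only a few nonzero $F(S)$, even though there are $2^n$ of them, and each coefficient equals $\mathbb{E}_m[f(m)(-1)^{|S \cap m|}]$ under uniformly random masks. This sketch estimates chosen coefficients by Monte Carlo; it does not implement SPEX, whose sparse Fourier transform and channel decoding are what locate the important $S$ without knowing them in advance. `toy_model` is a hypothetical stand-in for an LLM, and the parity convention is our assumption.

```python
import random


def chi(S, m):
    """Parity basis function (-1)^{|S ∩ m|}, with S and the mask m as bitmasks."""
    return -1 if bin(S & m).count("1") % 2 else 1


def toy_model(m):
    """Hypothetical masked-input score: an interaction between features 2 and 5
    plus a main effect of feature 7."""
    def bit(i):
        return (m >> i) & 1
    return 3.0 * (bit(2) & bit(5)) + 1.0 * bit(7)


def estimate_coeffs(f, n, candidates, num_samples, seed=0):
    """Monte Carlo estimate of F(S) = E_m[f(m) * chi_S(m)] over uniform random masks."""
    rng = random.Random(seed)
    totals = {S: 0.0 for S in candidates}
    for _ in range(num_samples):
        m = rng.getrandbits(n)  # each of the n features is kept with probability 1/2
        y = f(m)
        for S in candidates:
            totals[S] += y * chi(S, m)
    return {S: t / num_samples for S, t in totals.items()}


n = 20
candidates = [0, 1 << 2, 1 << 5, (1 << 2) | (1 << 5), 1 << 7, 1 << 3]
est = estimate_coeffs(toy_model, n, candidates, num_samples=20000)
for S in candidates:
    print(sorted(i for i in range(n) if S >> i & 1), round(est[S], 2))
```

With this convention the exact coefficients are $F(\emptyset) = 1.25$, $F(\{2\}) = F(\{5\}) = -0.75$, $F(\{2,5\}) = 0.75$, $F(\{7\}) = -0.5$, and $F(\{3\}) = 0$; the estimates should land within a few hundredths of these.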

LG · Jun 10, 2025
Local MDI+: Local Feature Importances for Tree-Based Models

Zhongyuan Liang, Zachary T. Rewolinski, Abhineet Agarwal et al.

Tree-based ensembles such as random forests remain the go-to choice over deep learning models for tabular data due to their prediction performance and computational efficiency. These advantages have led to their widespread deployment in high-stakes domains, where interpretability is essential for ensuring trustworthy predictions. This has motivated the development of popular local (i.e., sample-specific) feature importance (LFI) methods such as LIME and TreeSHAP. However, these approaches rely on approximations that ignore the model's internal structure and instead depend on potentially unstable perturbations. These issues are addressed in the global setting by MDI+, a feature importance method that exploits an equivalence between decision trees and linear models on a transformed node basis. However, global MDI+ scores cannot explain individual predictions when samples have heterogeneous characteristics. To address this gap, we propose Local MDI+ (LMDI+), a novel extension of the MDI+ framework to the sample-specific setting. LMDI+ outperforms the existing baselines LIME and TreeSHAP in identifying instance-specific signal features, averaging a 10% improvement in downstream task performance across twelve real-world benchmark datasets. It also demonstrates greater stability, consistently producing similar instance-level feature importance rankings across multiple random forest fits. Finally, LMDI+ enables local interpretability use cases, including the identification of closer counterfactuals and the discovery of homogeneous subgroups.
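A per-sample view of a tree can be sketched with the classic decision-path decomposition (often attributed to Saabas), which we swap in here for illustration. It is not LMDI+, but it shows how node means along one sample's path split that sample's prediction into a root mean plus feature-level terms, in the same node-based spirit. The dictionary tree format and the function name are hypothetical.

```python
def path_contributions(tree, x):
    """Split one tree's prediction for x into a root mean plus per-feature terms.

    Internal nodes: {"feature", "threshold", "mean", "left", "right"}; leaves: {"mean"}.
    Each step down the decision path credits (child mean - parent mean) to the split feature.
    """
    contrib = {}
    node = tree
    while "feature" in node:
        go_left = x[node["feature"]] <= node["threshold"]
        child = node["left"] if go_left else node["right"]
        f = node["feature"]
        contrib[f] = contrib.get(f, 0.0) + child["mean"] - node["mean"]
        node = child
    return tree["mean"], contrib, node["mean"]


tree = {
    "feature": 0, "threshold": 0.5, "mean": 10.0,
    "left": {"mean": 6.0},
    "right": {
        "feature": 1, "threshold": 2.0, "mean": 14.0,
        "left": {"mean": 12.0},
        "right": {"mean": 20.0},
    },
}
bias, contrib, pred = path_contributions(tree, [0.9, 3.0])
print(bias, contrib, pred)  # 10.0 {0: 4.0, 1: 6.0} 20.0
```

The root mean plus the contributions always equals the tree's prediction for that sample; averaging contributions across the trees of a forest gives one simple per-sample importance.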

LG · May 23, 2025
ProxySPEX: Inference-Efficient Interpretability via Sparse Feature Interactions in LLMs

Landon Butler, Abhineet Agarwal, Justin Singh Kang et al. · berkeley

Large Language Models (LLMs) have achieved remarkable performance by capturing complex interactions between input features. To identify these interactions, most existing approaches require enumerating all possible combinations of features up to a given order, causing them to scale poorly with the number of inputs $n$. Recently, Kang et al. (2025) proposed SPEX, an information-theoretic approach that uses interaction sparsity to scale to $n \approx 10^3$ features. SPEX greatly improves upon prior methods but requires tens of thousands of model inferences, which can be prohibitive for large models. In this paper, we observe that LLM feature interactions are often hierarchical (higher-order interactions are accompanied by their lower-order subsets), which enables more efficient discovery. To exploit this hierarchy, we propose ProxySPEX, an interaction attribution algorithm that first fits gradient boosted trees to masked LLM outputs and then extracts the important interactions. Experiments across four challenging high-dimensional datasets show that ProxySPEX reconstructs LLM outputs 20% more faithfully than marginal attribution approaches while using $10\times$ fewer inferences than SPEX. By accounting for interactions, ProxySPEX efficiently identifies the most influential features, providing a scalable approximation of their Shapley values. Further, we apply ProxySPEX to two interpretability tasks: data attribution, where we identify interactions among CIFAR-10 training samples that influence test predictions, and mechanistic interpretability, where we uncover interactions between attention heads, both within and across layers, on a question-answering task.
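Reading candidate interactions off a tree ensemble can be sketched as follows, assuming the trees have already been fit to (mask, LLM output) pairs: each root-to-node path contributes the set of features split on along it. The hierarchy described in the abstract appears naturally here, because every path set comes together with the sets of its prefixes. This is a simplified stand-in, not ProxySPEX's extraction step, which must also decide which interactions are important; the tree format and names are hypothetical.

```python
def path_feature_sets(node, prefix=frozenset()):
    """Feature sets along every root-to-node path of one tree (candidate interactions).

    Internal nodes: {"feature", "left", "right"}; leaves: any dict without "feature".
    """
    found = set()
    if "feature" in node:
        current = prefix | {node["feature"]}
        found.add(current)
        found |= path_feature_sets(node["left"], current)
        found |= path_feature_sets(node["right"], current)
    return found


def candidate_interactions(trees):
    """Union of path feature sets over an ensemble, e.g. boosted trees fit to masks."""
    out = set()
    for t in trees:
        out |= path_feature_sets(t)
    return out


leaf = {}
tree_a = {"feature": 2, "left": leaf, "right": {"feature": 5, "left": leaf, "right": leaf}}
tree_b = {"feature": 7, "left": leaf, "right": leaf}
for s in sorted(candidate_interactions([tree_a, tree_b]), key=lambda s: (len(s), sorted(s))):
    print(sorted(s))  # [2], then [7], then [2, 5]
```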

LG · Mar 9, 2025
Adaptive Test-Time Intervention for Concept Bottleneck Models

Matthew Shen, Aliyah Hsu, Abhineet Agarwal et al.

Concept bottleneck models (CBMs) aim to improve model interpretability by predicting human-level "concepts" in a bottleneck within a deep learning model architecture. However, the way the predicted concepts are used to predict the target either remains a black box or is simplified to maintain interpretability at the cost of prediction performance. We propose to use Fast Interpretable Greedy Sum-Trees (FIGS) to obtain Binary Distillation (BD). This new method, called FIGS-BD, distills a binary-augmented concept-to-target portion of the CBM into an interpretable tree-based model, while maintaining the competitive prediction performance of the CBM teacher. FIGS-BD can be used in downstream tasks to explain and decompose CBM predictions into interpretable binary-concept-interaction attributions and to guide adaptive test-time intervention. Across 4 datasets, we demonstrate that our adaptive test-time intervention identifies key concepts that significantly improve performance in realistic human-in-the-loop settings that allow only limited concept interventions.

LG · Jan 28, 2022
Fast Interpretable Greedy-Tree Sums

Yan Shuo Tan, Chandan Singh, Keyan Nasseri et al.

Modern machine learning has achieved impressive prediction performance, but often sacrifices interpretability, a critical consideration in high-stakes domains such as medicine. In such settings, practitioners often use highly interpretable decision tree models, but these suffer from an inductive bias against additive structure. To overcome this bias, we propose Fast Interpretable Greedy-Tree Sums (FIGS), which generalizes the CART algorithm to simultaneously grow a flexible number of trees in summation. By combining logical rules with addition, FIGS is able to adapt to additive structure while remaining highly interpretable. Extensive experiments on real-world datasets show that FIGS achieves state-of-the-art prediction performance. To demonstrate the usefulness of FIGS in high-stakes domains, we adapt FIGS to learn clinical decision instruments (CDIs), which are tools for guiding clinical decision-making. Specifically, we introduce a variant of FIGS known as G-FIGS that accounts for the heterogeneity in medical data. G-FIGS derives CDIs that reflect domain knowledge and enjoy improved specificity (by up to 20% over CART) without sacrificing sensitivity or interpretability. To provide further insight into FIGS, we prove that FIGS learns components of additive models, a property we refer to as disentanglement. Further, we show (under oracle conditions) that unconstrained tree-sum models leverage disentanglement to generalize more efficiently than single decision tree models when fitted to additive regression functions. Finally, to avoid overfitting with an unconstrained number of splits, we develop Bagging-FIGS, an ensemble version of FIGS that borrows the variance reduction techniques of random forests. Bagging-FIGS enjoys competitive performance with random forests and XGBoost on real-world datasets.

ML · Oct 18, 2021
A cautionary tale on fitting decision trees to data from additive models: generalization lower bounds

Yan Shuo Tan, Abhineet Agarwal, Bin Yu

Decision trees are important both as interpretable models amenable to high-stakes decision-making and as building blocks of ensemble methods such as random forests and gradient boosting. Their statistical properties, however, are not well understood. The most cited prior works have focused on deriving pointwise consistency guarantees for CART in a classical nonparametric regression setting. We take a different approach, and advocate studying the generalization performance of decision trees with respect to different generative regression models. This allows us to elicit their inductive bias, that is, the assumptions the algorithms make (or do not make) to generalize to new data, thereby guiding practitioners on when and how to apply these methods. In this paper, we focus on sparse additive generative models, which have both low statistical complexity and some nonparametric flexibility. We prove a sharp squared error generalization lower bound for a large class of decision tree algorithms fitted to sparse additive models with $C^1$ component functions. Surprisingly, this bound is much worse than the minimax rate for estimating such sparse additive models. The inefficiency is due not to greediness, but to the loss in power for detecting global structure when we average responses solely over each leaf, an observation that suggests opportunities to improve tree-based algorithms, for example, by hierarchical shrinkage. To prove these bounds, we develop new technical machinery, establishing a novel connection between decision tree estimation and rate-distortion theory, a subfield of information theory.