Rich Caruana

LG
h-index: 38
42 papers
5,887 citations
Novelty: 48%
AI Score: 49

42 Papers

LG · Feb 27, 2023 · Code
GAM Coach: Towards Interactive and User-centered Algorithmic Recourse

Zijie J. Wang, Jennifer Wortman Vaughan, Rich Caruana et al. · gatech, microsoft-research

Machine learning (ML) recourse techniques are increasingly used in high-stakes domains, providing end users with actions to alter ML predictions, but they assume ML developers understand what input variables can be changed. However, a recourse plan's actionability is subjective and unlikely to match developers' expectations completely. We present GAM Coach, a novel open-source system that adapts integer linear programming to generate customizable counterfactual explanations for Generalized Additive Models (GAMs), and leverages interactive visualizations to enable end users to iteratively generate recourse plans meeting their needs. A quantitative user study with 41 participants shows our tool is usable and useful, and users prefer personalized recourse plans over generic plans. Through a log analysis, we explore how users discover satisfactory recourse plans, and provide empirical evidence that transparency can lead to more opportunities for everyday users to discover counterintuitive patterns in ML models. GAM Coach is available at: https://poloclub.github.io/gam-coach/.
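The idea of user-constrained recourse on a GAM can be illustrated with a toy search. This is only a sketch under stated assumptions: the feature names, bin scores, and cost rule are hypothetical, and a plain exhaustive search stands in for the integer linear program mentioned in the abstract. It is not GAM Coach's actual API.

```python
from itertools import product

# Hypothetical GAM: score = intercept + sum of per-feature bin scores.
# The applicant is approved when the score is >= 0.
intercept = -1.0
shape = {
    "income": [-0.8, 0.0, 0.9],  # score for bin 0, 1, 2
    "debt": [0.6, 0.0, -0.7],
    "age": [-0.2, 0.1, 0.2],
}

def score(bins):
    """Additive GAM score for a mapping feature -> bin index."""
    return intercept + sum(shape[f][b] for f, b in bins.items())

def cheapest_recourse(current, changeable):
    """Exhaustive search (a stand-in for ILP) over bins of the changeable features.

    Cost is (number of features changed, total bin distance); lower is better.
    Returns the cheapest approved assignment, or None if none exists.
    """
    options = [range(len(shape[f])) if f in changeable else [current[f]] for f in shape]
    best = None
    for combo in product(*options):
        candidate = dict(zip(shape, combo))
        if score(candidate) < 0:
            continue  # still rejected, not a valid recourse plan
        cost = (
            sum(candidate[f] != current[f] for f in shape),
            sum(abs(candidate[f] - current[f]) for f in shape),
        )
        if best is None or cost < best[0]:
            best = (cost, candidate)
    return best[1] if best else None

current = {"income": 1, "debt": 0, "age": 0}
plan = cheapest_recourse(current, changeable={"income", "debt"})
print("current score:", round(score(current), 2))
print("recourse plan:", plan)
```

Restricting `changeable` is how a user would express which features they consider actionable; if only `age` were marked changeable in this toy model, no plan would exist and the function would return `None`.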

AI · Sep 23, 2022
Augmenting Interpretable Models with LLMs during Training

Chandan Singh, Armin Askari, Rich Caruana et al. · berkeley

Recent large language models (LLMs) have demonstrated remarkable prediction performance for a growing array of tasks. However, their proliferation into high-stakes domains (e.g. medicine) and compute-limited settings has created a burgeoning need for interpretability and efficiency. We address this need by proposing Augmented Interpretable Models (Aug-imodels), a framework for leveraging the knowledge learned by LLMs to build extremely efficient and interpretable models. Aug-imodels use LLMs during fitting but not during inference, allowing complete transparency and often a speed/memory improvement of greater than 1,000x for inference compared to LLMs. We explore two instantiations of Aug-imodels in natural-language processing: (i) Aug-GAM, which augments a generalized additive model with decoupled embeddings from an LLM and (ii) Aug-Tree, which augments a decision tree with LLM feature expansions. Across a variety of text-classification datasets, both outperform their non-augmented counterparts. Aug-GAM can even outperform much larger models (e.g. a 6-billion parameter GPT-J model), despite having 10,000x fewer parameters and being fully transparent. We further explore Aug-imodels in a natural-language fMRI study, where they generate interesting interpretations from scientific data. All code for using Aug-imodels and reproducing results is made available on GitHub.

LG · Jun 30, 2022
Interpretability, Then What? Editing Machine Learning Models to Reflect Human Knowledge and Values

Zijie J. Wang, Alex Kale, Harsha Nori et al. · gatech, microsoft-research

Machine learning (ML) interpretability techniques can reveal undesirable patterns in data that models exploit to make predictions--potentially causing harms once deployed. However, how to take action to address these patterns is not always clear. In a collaboration between ML and human-computer interaction researchers, physicians, and data scientists, we develop GAM Changer, the first interactive system to help domain experts and data scientists easily and responsibly edit Generalized Additive Models (GAMs) and fix problematic patterns. With novel interaction techniques, our tool puts interpretability into action--empowering users to analyze, validate, and align model behaviors with their knowledge and values. Physicians have started to use our tool to investigate and fix pneumonia and sepsis risk prediction models, and an evaluation with 7 data scientists working in diverse domains highlights that our tool is easy to use, meets their model editing needs, and fits into their current workflows. Built with modern web technologies, our tool runs locally in users' web browsers or computational notebooks, lowering the barrier to use. GAM Changer is available at the following public demo link: https://interpret.ml/gam-changer.

ML · Aug 2, 2023 · Code
LLMs Understand Glass-Box Models, Discover Surprises, and Suggest Repairs

Benjamin J. Lengerich, Sebastian Bordt, Harsha Nori et al.

We show that large language models (LLMs) are remarkably good at working with interpretable models that decompose complex outcomes into univariate graph-represented components. By adopting a hierarchical approach to reasoning, LLMs can provide comprehensive model-level summaries without ever requiring the entire model to fit in context. This approach enables LLMs to apply their extensive background knowledge to automate common tasks in data science such as detecting anomalies that contradict prior knowledge, describing potential reasons for the anomalies, and suggesting repairs that would remove the anomalies. We use multiple examples in healthcare to demonstrate the utility of these new capabilities of LLMs, with particular emphasis on Generalized Additive Models (GAMs). Finally, we present the package TalkToEBM as an open-source LLM-GAM interface.

LG · Apr 23, 2023
Missing Values and Imputation in Healthcare Data: Can Interpretable Machine Learning Help?

Zhi Chen, Sarah Tan, Urszula Chajewska et al.

Missing values are a fundamental problem in data science. Many datasets have missing values that must be properly handled because the way missing values are treated can have a large impact on the resulting machine learning model. In medical applications, the consequences may affect healthcare decisions. There are many methods in the literature for dealing with missing values, including state-of-the-art methods which often depend on black-box models for imputation. In this work, we show how recent advances in interpretable machine learning provide a new perspective for understanding and tackling the missing value problem. We propose methods based on high-accuracy glass-box Explainable Boosting Machines (EBMs) that can help users (1) gain new insights on missingness mechanisms and better understand the causes of missingness, and (2) detect -- or even alleviate -- potential risks introduced by imputation algorithms. Experiments on real-world medical datasets illustrate the effectiveness of the proposed methods.

LG · Oct 16, 2023
Interpretable Predictive Models to Understand Risk Factors for Maternal and Fetal Outcomes

Tomas M. Bosschieter, Zifei Xu, Hui Lan et al.

Although most pregnancies result in a good outcome, complications are not uncommon and can be associated with serious implications for mothers and babies. Predictive modeling has the potential to improve outcomes through better understanding of risk factors, heightened surveillance for high-risk patients, and more timely and appropriate interventions, thereby helping obstetricians deliver better care. We identify and study the most important risk factors for four types of pregnancy complications: (i) severe maternal morbidity, (ii) shoulder dystocia, (iii) preterm preeclampsia, and (iv) antepartum stillbirth. We use an Explainable Boosting Machine (EBM), a high-accuracy glass-box learning method, for prediction and identification of important risk factors. We undertake external validation and perform an extensive robustness analysis of the EBM models. EBMs match the accuracy of black-box ML methods such as deep neural networks and random forests, and outperform logistic regression, while being more interpretable. EBMs prove to be robust. The interpretability of the EBM models reveals surprising insights into the features contributing to risk (e.g. maternal height is the second most important feature for shoulder dystocia) and may have potential for clinical application in the prediction and prevention of serious complications in pregnancy.

LG · Nov 15, 2022
Estimating Discontinuous Time-Varying Risk Factors and Treatment Benefits for COVID-19 with Interpretable ML

Benjamin Lengerich, Mark E. Nunnally, Yin Aphinyanaphongs et al.

Treatment protocols, disease understanding, and viral characteristics changed over the course of the COVID-19 pandemic; as a result, the risks associated with patient comorbidities and biomarkers also changed. We add to the conversation regarding inflammation, hemostasis and vascular function in COVID-19 by performing a time-varying observational analysis of over 4,000 patients hospitalized for COVID-19 in a New York City hospital system from March 2020 to August 2021. To perform this analysis, we apply tree-based generalized additive models with temporal interactions which recover discontinuous risk changes caused by discrete protocol changes. We find that the biomarkers of thrombosis increasingly predicted mortality from March 2020 to August 2021, while the association between biomarkers of inflammation and thrombosis weakened. Beyond COVID-19, this presents a straightforward methodology to estimate unknown and discontinuous time-varying effects.

LG · Jul 12, 2022
Using Interpretable Machine Learning to Predict Maternal and Fetal Outcomes

Tomas M. Bosschieter, Zifei Xu, Hui Lan et al.

Most pregnancies and births result in a good outcome, but complications are not uncommon and when they do occur, they can be associated with serious implications for mothers and babies. Predictive modeling has the potential to improve outcomes through better understanding of risk factors, heightened surveillance, and more timely and appropriate interventions, thereby helping obstetricians deliver better care. For three types of complications we identify and study the most important risk factors using an Explainable Boosting Machine (EBM), a glass-box model, in order to gain intelligibility: (i) Severe Maternal Morbidity (SMM), (ii) shoulder dystocia, and (iii) preterm preeclampsia. While using the interpretability of EBMs to reveal surprising insights into the features contributing to risk, our experiments show EBMs match the accuracy of black-box ML methods such as deep neural nets and random forests.

LG · Apr 9, 2024 · Code
Elephants Never Forget: Memorization and Learning of Tabular Data in Large Language Models

Sebastian Bordt, Harsha Nori, Vanessa Rodrigues et al.

While many have shown how Large Language Models (LLMs) can be applied to a diverse set of tasks, the critical issues of data contamination and memorization are often glossed over. In this work, we address this concern for tabular data. Specifically, we introduce a variety of different techniques to assess whether a language model has seen a tabular dataset during training. This investigation reveals that LLMs have memorized many popular tabular datasets verbatim. We then compare the few-shot learning performance of LLMs on datasets that were seen during training to the performance on datasets released after training. We find that LLMs perform better on datasets seen during training, indicating that memorization leads to overfitting. At the same time, LLMs show non-trivial performance on novel datasets and are surprisingly robust to data transformations. We then investigate the in-context statistical learning abilities of LLMs. While LLMs are significantly better than random at solving statistical classification problems, the sample efficiency of few-shot learning lags behind traditional statistical learning algorithms, especially as the dimension of the problem increases. This suggests that much of the observed few-shot performance on novel real-world datasets is due to the LLM's world knowledge. Overall, our results highlight the importance of testing whether an LLM has seen an evaluation dataset during pre-training. We release the https://github.com/interpretml/LLM-Tabular-Memorization-Checker Python package to test LLMs for memorization of tabular datasets.

AI · Jun 29, 2023
Diagnosis Uncertain Models For Medical Risk Prediction

Alexander Peysakhovich, Rich Caruana, Yin Aphinyanaphongs

We consider a patient risk model that has access to patient features such as vital signs, lab values, and prior history but does not have access to a patient's diagnosis. For example, this occurs in a model deployed at intake time for triage purposes. We show that such 'all-cause' risk models have good generalization across diagnoses but have a predictable failure mode. When the same lab/vital/history profiles can result from diagnoses with different risk profiles (e.g. E. coli vs. MRSA), the risk estimate is a probability-weighted average of these two profiles. This leads to an under-estimation of risk for rare but highly risky diagnoses. We propose a fix for this problem by explicitly modeling the uncertainty in risk prediction coming from uncertainty in patient diagnoses. This gives practitioners an interpretable way to understand patient risk beyond a single risk number.
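The averaging failure mode described above can be made concrete with a short calculation. The probabilities and risks below are hypothetical numbers chosen only for illustration; they do not come from the paper, and the function names are likewise assumptions.

```python
# Hypothetical numbers: two diagnoses that produce the same
# lab/vital/history profile but carry very different risks.
diagnoses = {
    "common_low_risk": {"prob": 0.95, "risk": 0.02},
    "rare_high_risk": {"prob": 0.05, "risk": 0.40},
}

def all_cause_risk(diagnoses):
    """Risk seen by a model without diagnosis access: a probability-weighted average."""
    return sum(d["prob"] * d["risk"] for d in diagnoses.values())

def risk_by_diagnosis(diagnoses):
    """Diagnosis-aware view: keep (probability, risk) per diagnosis instead of one number."""
    return {name: (d["prob"], d["risk"]) for name, d in diagnoses.items()}

blended = all_cause_risk(diagnoses)
worst_case = max(d["risk"] for d in diagnoses.values())
print(f"all-cause risk estimate: {blended:.3f}")
print(f"risk if the rare diagnosis is the true one: {worst_case:.2f}")
print("per-diagnosis view:", risk_by_diagnosis(diagnoses))
```

With these numbers the blended estimate is far below the risk carried by the rare diagnosis, which is the under-estimation the abstract refers to; reporting the per-diagnosis pairs keeps that information visible.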

LG · Mar 11, 2024 · Code
Elephants Never Forget: Testing Language Models for Memorization of Tabular Data

Sebastian Bordt, Harsha Nori, Rich Caruana

While many have shown how Large Language Models (LLMs) can be applied to a diverse set of tasks, the critical issues of data contamination and memorization are often glossed over. In this work, we address this concern for tabular data. Starting with simple qualitative tests for whether an LLM knows the names and values of features, we introduce a variety of different techniques to assess the degrees of contamination, including statistical tests for conditional distribution modeling and four tests that identify memorization. Our investigation reveals that LLMs are pre-trained on many popular tabular datasets. This exposure can lead to invalid performance evaluation on downstream tasks because the LLMs have, in effect, been fit to the test set. Interestingly, we also identify a regime where the language model reproduces important statistics of the data, but fails to reproduce the dataset verbatim. On these datasets, although seen during training, good performance on downstream tasks might not be due to overfitting. Our findings underscore the need for ensuring data integrity in machine learning tasks with LLMs. To facilitate future research, we release an open-source tool that can perform various tests for memorization: https://github.com/interpretml/LLM-Tabular-Memorization-Checker.

LG · Feb 22, 2024 · Code
Data Science with LLMs and Interpretable Models

Sebastian Bordt, Ben Lengerich, Harsha Nori et al.

Recent years have seen important advances in the building of interpretable models, machine learning models that are designed to be easily understood by humans. In this work, we show that large language models (LLMs) are remarkably good at working with interpretable models, too. In particular, we show that LLMs can describe, interpret, and debug Generalized Additive Models (GAMs). Combining the flexibility of LLMs with the breadth of statistical patterns accurately described by GAMs enables dataset summarization, question answering, and model critique. LLMs can also improve the interaction between domain experts and interpretable models, and generate hypotheses about the underlying phenomenon. We release https://github.com/interpretml/TalkToEBM as an open-source LLM-GAM interface.

64.9 · LG · Apr 14
Selecting Feature Interactions for Generalized Additive Models by Distilling Foundation Models

Jingyun Jia, Chandan Singh, Rich Caruana et al.

Identifying meaningful feature interactions is a central challenge in building accurate and interpretable models for tabular data. Generalized additive models (GAMs) have shown great success at modeling tabular data, but often rely on heuristic procedures to select interactions, potentially missing higher-order or context-dependent effects. To meet this challenge, we propose TabDistill, a method that leverages tabular foundation models and post-hoc distillation methods. Our key intuition is that tabular foundation models implicitly learn rich, adaptive feature dependencies through large-scale representation learning. Given a dataset, TabDistill first fits a tabular foundation model to the dataset, and then applies a post-hoc interaction attribution method to extract salient feature interactions from it. We evaluate these interactions by then using them as terms in a GAM. Across tasks, we find that interactions identified by TabDistill lead to consistent improvements in downstream GAMs' predictive performance. Our results suggest that tabular foundation models can serve as effective, data-driven guides for interaction discovery, bridging high-capacity models and interpretable additive frameworks.

LG · Dec 6, 2021 · Code
GAM Changer: Editing Generalized Additive Models with Interactive Visualization

Zijie J. Wang, Alex Kale, Harsha Nori et al.

Recent strides in interpretable machine learning (ML) research reveal that models exploit undesirable patterns in the data to make predictions, which potentially causes harms in deployment. However, it is unclear how we can fix these models. We present our ongoing work, GAM Changer, an open-source interactive system to help data scientists and domain experts easily and responsibly edit their Generalized Additive Models (GAMs). With novel visualization techniques, our tool puts interpretability into action -- empowering human users to analyze, validate, and align model behaviors with their knowledge and values. Built using modern web technologies, our tool runs locally in users' computational notebooks or web browsers without requiring extra compute resources, lowering the barrier to creating more responsible ML models. GAM Changer is available at https://interpret.ml/gam-changer.

LG · Sep 19, 2019 · Code
InterpretML: A Unified Framework for Machine Learning Interpretability

Harsha Nori, Samuel Jenkins, Paul Koch et al.

InterpretML is an open-source Python package which exposes machine learning interpretability algorithms to practitioners and researchers. InterpretML exposes two types of interpretability: glassbox models, which are machine learning models designed for interpretability (e.g., linear models, rule lists, generalized additive models), and blackbox explainability techniques for explaining existing systems (e.g., Partial Dependence, LIME). The package enables practitioners to easily compare interpretability algorithms by exposing multiple methods under a unified API, and by having a built-in, extensible visualization platform. InterpretML also includes the first implementation of the Explainable Boosting Machine, a powerful, interpretable, glassbox model that can be as accurate as many blackbox models. The MIT-licensed source code can be downloaded from github.com/microsoft/interpret.

CL · Jan 30, 2024
Rethinking Interpretability in the Era of Large Language Models

Chandan Singh, Jeevana Priya Inala, Michel Galley et al.

Interpretable machine learning has exploded as an area of interest over the last decade, sparked by the rise of increasingly large datasets and deep neural networks. Simultaneously, large language models (LLMs) have demonstrated remarkable capabilities across a wide array of tasks, offering a chance to rethink opportunities in interpretable machine learning. Notably, the capability to explain in natural language allows LLMs to expand the scale and complexity of patterns that can be given to a human. However, these new capabilities raise new challenges, such as hallucinated explanations and immense computational costs. In this position paper, we start by reviewing existing methods to evaluate the emerging field of LLM interpretation (both interpreting LLMs and using LLMs for explanation). We contend that, despite their limitations, LLMs hold the opportunity to redefine interpretability with a more ambitious scope across many applications, including in auditing LLMs themselves. We highlight two emerging research priorities for LLM interpretation: using LLMs to directly analyze new datasets and to generate interactive explanations.

LG · Nov 22, 2023
Explaining high-dimensional text classifiers

Odelia Melamed, Rich Caruana

Explainability has become a valuable tool in the last few years, helping humans better understand AI-guided decisions. However, the classic explainability tools are sometimes quite limited when considering high-dimensional inputs and neural network classifiers. We present a new explainability method using theoretically proven high-dimensional properties in neural network classifiers. We present two uses of it: 1) the classical sentiment analysis task on the IMDB reviews dataset, and 2) a malware-detection task on our PowerShell scripts dataset.

LG · Jun 24, 2025
The Most Important Features in Generalized Additive Models Might Be Groups of Features

Tomas M. Bosschieter, Luis Franca, Jessica Wolk et al.

While analyzing the importance of features has become ubiquitous in interpretable machine learning, the joint signal from a group of related features is sometimes overlooked or inadvertently excluded. Neglecting the joint signal could bypass a critical insight: in many instances, the most significant predictors are not isolated features, but rather the combined effect of groups of features. This can be especially problematic for datasets that contain natural groupings of features, including multimodal datasets. This paper introduces a novel approach to determine the importance of a group of features for Generalized Additive Models (GAMs) that is efficient, requires no model retraining, allows defining groups post hoc, permits overlapping groups, and remains meaningful in high-dimensional settings. Moreover, this definition offers a parallel with explained variation in statistics. We showcase properties of our method on three synthetic experiments that illustrate the behavior of group importance across various data regimes. We then demonstrate the importance of groups of features in identifying depressive symptoms from a multimodal neuroscience dataset, and study the importance of social determinants of health after total hip arthroplasty. These two case studies reveal that analyzing group importance offers a more accurate, holistic view of the medical issues compared to a single-feature analysis.

CV · May 25, 2023
Extending Explainable Boosting Machines to Scientific Image Data

Daniel Schug, Sai Yerramreddy, Rich Caruana et al.

As the deployment of computer vision technology becomes increasingly common in science, the need for explanations of the system and its output has become a focus of great concern. Driven by the pressing need for interpretable models in science, we propose the use of Explainable Boosting Machines (EBMs) for scientific image data. Inspired by an important application underpinning the development of quantum technologies, we apply EBMs to cold-atom soliton image data tabularized using Gabor Wavelet Transform-based techniques that preserve the spatial structure of the data. In doing so, we demonstrate the use of EBMs for image data for the first time and show that our approach provides explanations that are consistent with human intuition about the data.

ML · Feb 22, 2022
Differentially Private Estimation of Heterogeneous Causal Effects

Fengshi Niu, Harsha Nori, Brian Quistorff et al.

Estimating heterogeneous treatment effects in domains such as healthcare or social science often involves sensitive data where protecting privacy is important. We introduce a general meta-algorithm for estimating conditional average treatment effects (CATE) with differential privacy (DP) guarantees. Our meta-algorithm can work with simple, single-stage CATE estimators such as S-learner and more complex multi-stage estimators such as DR and R-learner. We perform a tight privacy analysis by taking advantage of sample splitting in our meta-algorithm and the parallel composition property of differential privacy. In this paper, we implement our approach using DP-EBMs as the base learner. DP-EBMs are interpretable, high-accuracy models with privacy guarantees, which allow us to directly observe the impact of DP noise on the learned causal model. Our experiments show that multi-stage CATE estimators incur larger accuracy loss than single-stage CATE or ATE estimators and that most of the accuracy loss from differential privacy is due to an increase in variance, not biased estimates of treatment effects.

LG · Oct 28, 2021
Extracting Expert's Goals by What-if Interpretable Modeling

Chun-Hao Chang, George Alexandru Adam, Rich Caruana et al.

Although reinforcement learning (RL) has had tremendous success in many fields, applying RL to real-world settings such as healthcare is challenging when the reward is hard to specify and no exploration is allowed. In this work, we focus on recovering clinicians' rewards in treating patients. We incorporate what-if reasoning to explain the clinician's treatments based on their potential future outcomes. We use generalized additive models (GAMs), a class of accurate, interpretable models, to recover the reward. In both simulation and a real-world hospital dataset, we show our model outperforms baselines. Finally, our model's explanations match several clinical guidelines for treating patients, while we found that the commonly used linear model often contradicts them.

LG · Jun 17, 2021
Accuracy, Interpretability, and Differential Privacy via Explainable Boosting

Harsha Nori, Rich Caruana, Zhiqi Bu et al.

We show that adding differential privacy to Explainable Boosting Machines (EBMs), a recent method for training interpretable ML models, yields state-of-the-art accuracy while protecting privacy. Our experiments on multiple classification and regression datasets show that DP-EBM models suffer surprisingly little accuracy loss even with strong differential privacy guarantees. In addition to high accuracy, two other benefits of applying DP to EBMs are: a) trained models provide exact global and local interpretability, which is often important in settings where differential privacy is needed; and b) the models can be edited after training without loss of privacy to correct errors which DP noise may have introduced.

LG · Jun 3, 2021
NODE-GAM: Neural Generalized Additive Model for Interpretable Deep Learning

Chun-Hao Chang, Rich Caruana, Anna Goldenberg

Deployment of machine learning models in real high-risk settings (e.g. healthcare) often depends not only on the model's accuracy but also on its fairness, robustness, and interpretability. Generalized Additive Models (GAMs) are a class of interpretable models with a long history of use in these high-risk domains, but they lack desirable features of deep learning such as differentiability and scalability. In this work, we propose a neural GAM (NODE-GAM) and neural GA$^2$M (NODE-GA$^2$M) that scale well and perform better than other GAMs on large datasets, while remaining interpretable compared to other ensemble and deep learning models. We demonstrate that our models find interesting patterns in the data. Lastly, we show that we improve model accuracy via self-supervised pre-training, an improvement that is not possible for non-differentiable GAMs.

AO-PH · Feb 9, 2021
Sub-seasonal forecasting with a large ensemble of deep-learning weather prediction models

Jonathan A. Weyn, Dale R. Durran, Rich Caruana et al.

We present an ensemble prediction system using a Deep Learning Weather Prediction (DLWP) model that recursively predicts key atmospheric variables with six-hour time resolution. This model uses convolutional neural networks (CNNs) on a cubed sphere grid to produce global forecasts. The approach is computationally efficient, requiring just three minutes on a single GPU to produce a 320-member set of six-week forecasts at 1.4° resolution. Ensemble spread is primarily produced by randomizing the CNN training process to create a set of 32 DLWP models with slightly different learned weights. Although our DLWP model does not forecast precipitation, it does forecast total column water vapor, and it gives a reasonable 4.5-day deterministic forecast of Hurricane Irma. In addition to simulating mid-latitude weather systems, it spontaneously generates tropical cyclones in a one-year free-running simulation. Averaged globally and over a two-year test set, the ensemble mean RMSE retains skill relative to climatology beyond two weeks, with anomaly correlation coefficients remaining above 0.6 through six days. Our primary application is to subseasonal-to-seasonal (S2S) forecasting at lead times from two to six weeks. Current forecast systems have low skill in predicting one- or two-week-average weather patterns at S2S time scales. The continuous ranked probability score (CRPS) and the ranked probability skill score (RPSS) show that the DLWP ensemble is only modestly inferior in performance to the European Centre for Medium-Range Weather Forecasts (ECMWF) S2S ensemble over land at lead times of 4 and 5-6 weeks. At shorter lead times, the ECMWF ensemble performs better than DLWP.

LG · Jul 2, 2020
Dropout as a Regularizer of Interaction Effects

Benjamin Lengerich, Eric P. Xing, Rich Caruana

We examine Dropout through the perspective of interactions. This view provides a symmetry to explain Dropout: given $N$ variables, there are ${N \choose k}$ possible sets of $k$ variables to form an interaction (i.e. $\mathcal{O}(N^k)$); conversely, the probability an interaction of $k$ variables survives Dropout at rate $p$ is $(1-p)^k$ (decaying with $k$). These rates effectively cancel, and so Dropout regularizes against higher-order interactions. We prove this perspective analytically and empirically. This perspective of Dropout as a regularizer against interaction effects has several practical implications: (1) higher Dropout rates should be used when we need stronger regularization against spurious high-order interactions, (2) caution should be exercised when interpreting Dropout-based explanations and uncertainty measures, and (3) networks trained with Input Dropout are biased estimators. We also compare Dropout to other regularizers and find that it is difficult to obtain the same selective pressure against high-order interactions.
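The two rates named in the abstract can be tabulated directly. This is a small numerical sketch, not the paper's analysis: the values of `n_vars` and `drop_rate` are arbitrary assumptions, and the function names are illustrative.

```python
from math import comb

def num_interactions(n_vars, order):
    """Number of distinct sets of `order` variables among `n_vars` (N choose k)."""
    return comb(n_vars, order)

def survival_probability(drop_rate, order):
    """Probability that all `order` variables of an interaction are kept under Dropout."""
    return (1.0 - drop_rate) ** order

n_vars, drop_rate = 20, 0.5  # arbitrary illustrative values
for order in (1, 2, 3, 4):
    count = num_interactions(n_vars, order)
    keep = survival_probability(drop_rate, order)
    print(f"k={order}: candidate sets={count}, survival probability={keep:.4f}")
```

The count of candidate interactions grows with the order k while the chance that any single one survives shrinks geometrically, which is the symmetry the abstract describes.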

LG · Jun 11, 2020
How Interpretable and Trustworthy are GAMs?

Chun-Hao Chang, Sarah Tan, Ben Lengerich et al.

Generalized additive models (GAMs) have become a leading model class for interpretable machine learning. However, there are many algorithms for training GAMs, and these can learn different or even contradictory models, while being equally accurate. Which GAM should we trust? In this paper, we quantitatively and qualitatively investigate a variety of GAM algorithms on real and simulated datasets. We find that GAMs with high feature sparsity (only using a few variables to make predictions) can miss patterns in the data and be unfair to rare subpopulations. Our results suggest that inductive bias plays a crucial role in what interpretable models learn and that tree-based GAMs represent the best balance of sparsity, fidelity and accuracy and thus appear to be the most trustworthy GAM.

LG · Apr 29, 2020
Neural Additive Models: Interpretable Machine Learning with Neural Nets

Rishabh Agarwal, Levi Melnick, Nicholas Frosst et al.

Deep neural networks (DNNs) are powerful black-box predictors that have achieved impressive performance on a wide variety of tasks. However, their accuracy comes at the cost of intelligibility: it is usually unclear how they make their decisions. This hinders their applicability to high-stakes decision-making domains such as healthcare. We propose Neural Additive Models (NAMs) which combine some of the expressivity of DNNs with the inherent intelligibility of generalized additive models. NAMs learn a linear combination of neural networks that each attend to a single input feature. These networks are trained jointly and can learn arbitrarily complex relationships between their input feature and the output. Our experiments on regression and classification datasets show that NAMs are more accurate than widely used intelligible models such as logistic regression and shallow decision trees. They perform similarly to existing state-of-the-art generalized additive models in accuracy, but are more flexible because they are based on neural nets instead of boosted trees. To demonstrate this, we show how NAMs can be used for multitask learning on synthetic data and on the COMPAS recidivism data due to their composability, and demonstrate that the differentiability of NAMs allows them to train more complex interpretable models for COVID-19.
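The additive structure described above (one small network per input feature, with the outputs summed) can be sketched as a forward pass. This is a minimal illustration, not the authors' architecture or training code: the weights are hypothetical and untrained, and each per-feature network is assumed to be a one-hidden-layer ReLU net.

```python
def relu(x):
    return max(0.0, x)

def feature_net(x, hidden_w, hidden_b, out_w):
    """One-hidden-layer ReLU net mapping a single scalar feature to a scalar contribution."""
    return sum(w_out * relu(w * x + b) for w, b, w_out in zip(hidden_w, hidden_b, out_w))

# Hypothetical, untrained weights: one independent network per input feature.
nets = [
    {"hidden_w": [1.0, -1.0], "hidden_b": [0.0, 0.5], "out_w": [0.5, 2.0]},
    {"hidden_w": [2.0, 0.5], "hidden_b": [-1.0, 0.0], "out_w": [1.0, -1.0]},
]
bias = 0.1

def contributions(x):
    """Per-feature contributions; each depends only on its own feature value."""
    return [feature_net(xi, **net) for xi, net in zip(x, nets)]

def nam_predict(x):
    """Prediction is the bias plus the sum of the per-feature contributions."""
    return bias + sum(contributions(x))

x = [1.0, 2.0]
print("contributions:", contributions(x))
print("prediction:", nam_predict(x))
```

Because each contribution depends on a single feature, plotting `feature_net` over a feature's range gives the kind of per-feature shape plot that makes additive models intelligible.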

AO-PH Mar 15, 2020
Improving data-driven global weather prediction using deep convolutional neural networks on a cubed sphere

Jonathan A. Weyn, Dale R. Durran, Rich Caruana

We present a significantly-improved data-driven global weather forecasting framework using a deep convolutional neural network (CNN) to forecast several basic atmospheric variables on a global grid. New developments in this framework include an offline volume-conservative mapping to a cubed-sphere grid, improvements to the CNN architecture, and the minimization of the loss function over multiple steps in a prediction sequence. The cubed-sphere remapping minimizes the distortion on the cube faces on which convolution operations are performed and provides natural boundary conditions for padding in the CNN. Our improved model produces weather forecasts that are indefinitely stable and produce realistic weather patterns at lead times of several weeks and longer. For short- to medium-range forecasting, our model significantly outperforms persistence, climatology, and a coarse-resolution dynamical numerical weather prediction (NWP) model. Unsurprisingly, our forecasts are worse than those from a high-resolution state-of-the-art operational NWP system. Our data-driven model is able to learn to forecast complex surface temperature patterns from few input atmospheric state variables. On annual time scales, our model produces a realistic seasonal cycle driven solely by the prescribed variation in top-of-atmosphere solar forcing. Although it is currently less accurate than operational weather forecasting models, our data-driven CNN executes much faster than those models, suggesting that machine learning could prove to be a valuable tool for large-ensemble forecasting.

ML Nov 12, 2019
Purifying Interaction Effects with the Functional ANOVA: An Efficient Algorithm for Recovering Identifiable Additive Models

Benjamin Lengerich, Sarah Tan, Chun-Hao Chang et al.

Models which estimate main effects of individual variables alongside interaction effects have an identifiability challenge: effects can be freely moved between main effects and interaction effects without changing the model prediction. This is a critical problem for interpretability because it permits "contradictory" models to represent the same function. To solve this problem, we propose pure interaction effects: variance in the outcome which cannot be represented by any smaller subset of features. This definition has an equivalence with the Functional ANOVA decomposition. To compute this decomposition, we present a fast, exact algorithm that transforms any piecewise-constant function (such as a tree-based model) into a purified, canonical representation. We apply this algorithm to Generalized Additive Models with interactions trained on several datasets and show large disparity, including contradictions, between the effects before and after purification. These results underscore the need to specify data distributions and ensure identifiability before interpreting model parameters.
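
To make the idea of moving mass between main effects and interaction effects concrete, here is a rough sketch for a single pairwise table. It is an assumption-laden illustration rather than the paper's algorithm: the function name `purify_pairwise`, the iterative mean-subtraction scheme, and the toy table and weights are all hypothetical. The weights `W` stand in for the data distribution over the two binned features.

```python
import numpy as np

def purify_pairwise(F, W, tol=1e-12, max_iter=1000):
    """Split a pairwise table F into intercept, two main effects, and a residual.

    F: (a, b) interaction table; W: (a, b) positive weights (e.g. bin counts).
    Returns (intercept, main_rows, main_cols, pure) such that
    intercept + main_rows[i] + main_cols[j] + pure[i, j] == F[i, j]
    and every weighted row and column mean of `pure` is (numerically) zero.
    """
    pure = np.asarray(F, dtype=float).copy()
    main_rows = np.zeros(pure.shape[0])
    main_cols = np.zeros(pure.shape[1])
    for _ in range(max_iter):
        # Move the weighted mean of each row into the row main effect.
        row_means = (pure * W).sum(axis=1) / W.sum(axis=1)
        pure -= row_means[:, None]
        main_rows += row_means
        # Move the weighted mean of each column into the column main effect.
        col_means = (pure * W).sum(axis=0) / W.sum(axis=0)
        pure -= col_means[None, :]
        main_cols += col_means
        if max(np.abs(row_means).max(), np.abs(col_means).max()) < tol:
            break
    # Centre the main effects so the intercept carries the overall mean.
    r0 = (main_rows * W.sum(axis=1)).sum() / W.sum()
    c0 = (main_cols * W.sum(axis=0)).sum() / W.sum()
    return r0 + c0, main_rows - r0, main_cols - c0, pure

F = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 7.0]])
W = np.array([[1.0, 2.0, 1.0], [2.0, 1.0, 3.0]])
intercept, main_rows, main_cols, pure = purify_pairwise(F, W)
```

The model's predictions are unchanged by this rewrite; only the attribution between the intercept, main effects, and interaction differs, and the result depends on the weights chosen for `W`.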

LG May 31, 2019
Efficient Forward Architecture Search

Hanzhang Hu, John Langford, Rich Caruana et al.

We propose a neural architecture search (NAS) algorithm, Petridish, to iteratively add shortcut connections to existing network layers. The added shortcut connections effectively perform gradient boosting on the augmented layers. The proposed algorithm is motivated by the feature selection algorithm forward stage-wise linear regression, since we consider NAS as a generalization of feature selection for regression, where NAS selects shortcuts among layers instead of selecting features. In order to reduce the number of trials of possible connection combinations, we train jointly all possible connections at each stage of growth while leveraging feature selection techniques to choose a subset of them. We experimentally show this process to be an efficient forward architecture search algorithm that can find competitive models using few GPU days in both the search space of repeatable network modules (cell-search) and the space of general networks (macro-search). Petridish is particularly well-suited for warm-starting from existing models crucial for lifelong-learning scenarios.

LG Oct 22, 2018
Axiomatic Interpretability for Multiclass Additive Models

Xuezhou Zhang, Sarah Tan, Paul Koch et al.

Generalized additive models (GAMs) are favored in many regression and binary classification problems because they are able to fit complex, nonlinear functions while still remaining interpretable. In the first part of this paper, we generalize a state-of-the-art GAM learning algorithm based on boosted trees to the multiclass setting, and show that this multiclass algorithm outperforms existing GAM learning algorithms and sometimes matches the performance of full complexity models such as gradient boosted trees. In the second part, we turn our attention to the interpretability of GAMs in the multiclass setting. Surprisingly, the natural interpretability of GAMs breaks down when there are more than two classes. Naive interpretation of multiclass GAMs can lead to false conclusions. Inspired by binary GAMs, we identify two axioms that any additive model must satisfy in order to not be visually misleading. We then develop a technique called Additive Post-Processing for Interpretability (API), that provably transforms a pre-trained additive model to satisfy the interpretability axioms without sacrificing accuracy. The technique works not just on models trained with our learning algorithm, but on any multiclass additive model, including multiclass linear and logistic regression. We demonstrate the effectiveness of API on a 12-class infant mortality dataset.

ML Jan 26, 2018
Considerations When Learning Additive Explanations for Black-Box Models

Sarah Tan, Giles Hooker, Paul Koch et al.

Many methods to explain black-box models, whether local or global, are additive. In this paper, we study global additive explanations for non-additive models, focusing on four explanation methods: partial dependence, Shapley explanations adapted to a global setting, distilled additive explanations, and gradient-based explanations. We show that different explanation methods characterize non-additive components in a black-box model's prediction function in different ways. We use the concepts of main and total effects to anchor additive explanations, and quantitatively evaluate additive and non-additive explanations. Even though distilled explanations are generally the most accurate additive explanations, non-additive explanations such as tree explanations that explicitly model non-additive components tend to be even more accurate. Despite this, our user study showed that machine learning practitioners were better able to leverage additive explanations for various tasks. These considerations should be taken into account when considering which explanation to trust and use to explain black-box models.

ML Oct 17, 2017
Distill-and-Compare: Auditing Black-Box Models Using Transparent Model Distillation

Sarah Tan, Rich Caruana, Giles Hooker et al.

Black-box risk scoring models permeate our lives, yet are typically proprietary or opaque. We propose Distill-and-Compare, a model distillation and comparison approach to audit such models. To gain insight into black-box models, we treat them as teachers, training transparent student models to mimic the risk scores assigned by black-box models. We compare the student model trained with distillation to a second un-distilled transparent model trained on ground-truth outcomes, and use differences between the two models to gain insight into the black-box model. Our approach can be applied in a realistic setting, without probing the black-box model API. We demonstrate the approach on four public data sets: COMPAS, Stop-and-Frisk, Chicago Police, and Lending Club. We also propose a statistical test to determine if a data set is missing key features used to train the black-box model. Our test finds that the ProPublica data is likely missing key feature(s) used in COMPAS.

AI Jul 4, 2017
Interpretable & Explorable Approximations of Black Box Models

Himabindu Lakkaraju, Ece Kamar, Rich Caruana et al.

We propose Black Box Explanations through Transparent Approximations (BETA), a novel model-agnostic framework for explaining the behavior of any black-box classifier by simultaneously optimizing for fidelity to the original model and interpretability of the explanation. To this end, we develop a novel objective function which allows us to learn (with optimality guarantees) a small number of compact decision sets, each of which explains the behavior of the black box model in unambiguous, well-defined regions of feature space. Furthermore, our framework is also capable of accepting user input when generating these approximations, thus allowing users to interactively explore how the black-box model behaves in different subspaces that are of interest to the user. To the best of our knowledge, this is the first approach which can produce global explanations of the behavior of any given black box model through joint optimization of unambiguity, fidelity, and interpretability, while also allowing users to explore model behavior based on their preferences. Experimental evaluation with real-world datasets and user studies demonstrates that our approach can generate highly compact, easy-to-understand, yet accurate approximations of various kinds of predictive models compared to state-of-the-art baselines.

AI Oct 28, 2016
Identifying Unknown Unknowns in the Open World: Representations and Policies for Guided Exploration

Himabindu Lakkaraju, Ece Kamar, Rich Caruana et al.

Predictive models deployed in the real world may assign incorrect labels to instances with high confidence. Such errors or unknown unknowns are rooted in model incompleteness, and typically arise because of the mismatch between training data and the cases encountered at test time. As the models are blind to such errors, input from an oracle is needed to identify these failures. In this paper, we formulate and address the problem of informed discovery of unknown unknowns of any given predictive model where unknown unknowns occur due to systematic biases in the training data. We propose a model-agnostic methodology which uses feedback from an oracle to both identify unknown unknowns and to intelligently guide the discovery. We employ a two-phase approach which first organizes the data into multiple partitions based on the feature similarity of instances and the confidence scores assigned by the predictive model, and then utilizes an explore-exploit strategy for discovering unknown unknowns across these partitions. We demonstrate the efficacy of our framework by varying the underlying causes of unknown unknowns across various applications. To the best of our knowledge, this paper presents the first algorithmic approach to the problem of discovering unknown unknowns of predictive models.

ML Mar 17, 2016
Do Deep Convolutional Nets Really Need to be Deep and Convolutional?

Gregor Urban, Krzysztof J. Geras, Samira Ebrahimi Kahou et al.

Yes, they do. This paper provides the first empirical demonstration that deep convolutional models really need to be both deep and convolutional, even when trained with methods such as distillation that allow small or shallow models of high accuracy to be trained. Although previous research showed that shallow feed-forward nets sometimes can learn the complex functions previously learned by deep nets while using the same number of parameters as the deep models they mimic, in this paper we demonstrate that the same methods cannot be used to train accurate models on CIFAR-10 unless the student models contain multiple layers of convolution. Although the student models do not have to be as deep as the teacher model they mimic, the students need multiple convolutional layers to learn functions of comparable accuracy as the deep convolutional teacher.

IR Feb 2, 2016
A Dual Embedding Space Model for Document Ranking

Bhaskar Mitra, Eric Nalisnick, Nick Craswell et al.

A fundamental goal of search engines is to identify, given a query, documents that have relevant text. This is intrinsically difficult because the query and the document may use different vocabulary, or the document may contain query words without being relevant. We investigate neural word embeddings as a source of evidence in document ranking. We train a word2vec embedding model on a large unlabelled query corpus, but in contrast to how the model is commonly used, we retain both the input and the output projections, allowing us to leverage both the embedding spaces to derive richer distributional relationships. During ranking we map the query words into the input space and the document words into the output space, and compute a query-document relevance score by aggregating the cosine similarities across all the query-document word pairs. We postulate that the proposed Dual Embedding Space Model (DESM) captures evidence on whether a document is about a query term in addition to what is modelled by traditional term-frequency based approaches. Our experiments show that the DESM can re-rank top documents returned by a commercial Web search engine, like Bing, better than a term-matching based signal like TF-IDF. However, when ranking a larger set of candidate documents, we find the embeddings-based approach is prone to false positives, retrieving documents that are only loosely related to the query. We demonstrate that this problem can be solved effectively by ranking based on a linear mixture of the DESM and the word counting features.
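
The scoring scheme in this abstract can be sketched in a few lines. The sketch below makes several assumptions: the two-dimensional vectors are invented for illustration rather than taken from a trained word2vec model, the aggregation is a plain mean over all query-document word pairs, and the names `desm_score`, `mixed_score`, and the mixing weight `alpha` are hypothetical.

```python
import numpy as np

# Hypothetical toy embeddings (made up; not from a trained model).
IN_EMB = {"cambridge": np.array([1.0, 0.2])}   # input space: query words
OUT_EMB = {                                      # output space: document words
    "university": np.array([0.9, 0.3]),
    "college": np.array([0.8, 0.1]),
    "giraffe": np.array([-0.2, 1.0]),
    "zoo": np.array([0.0, 1.0]),
}

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def desm_score(query_words, doc_words, in_emb, out_emb):
    """Mean cosine similarity over all (query word, document word) pairs."""
    sims = [cosine(in_emb[q], out_emb[d]) for q in query_words for d in doc_words]
    return sum(sims) / len(sims)

def mixed_score(desm, term_score, alpha=0.5):
    """Linear mixture of the embedding score and a term-matching score."""
    return alpha * desm + (1.0 - alpha) * term_score

on_topic = desm_score(["cambridge"], ["university", "college"], IN_EMB, OUT_EMB)
off_topic = desm_score(["cambridge"], ["giraffe", "zoo"], IN_EMB, OUT_EMB)
```

With these toy vectors the on-topic document scores higher than the off-topic one. The mixture in `mixed_score` reflects the abstract's remedy for false positives on larger candidate sets, where a term-matching signal tempers loosely related matches.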

LG Nov 19, 2015
Blending LSTMs into CNNs

Krzysztof J. Geras, Abdel-rahman Mohamed, Rich Caruana et al.

We consider whether deep convolutional networks (CNNs) can represent decision functions with similar accuracy as recurrent networks such as LSTMs. First, we show that a deep CNN with an architecture inspired by the models recently introduced in image recognition can yield better accuracy than previous convolutional and LSTM networks on the standard 309h Switchboard automatic speech recognition task. Then we show that even more accurate CNNs can be trained under the guidance of LSTMs using a variant of model compression, which we call model blending because the teacher and student models are similar in complexity but different in inductive bias. Blending further improves the accuracy of our CNN, yielding a computationally efficient model of accuracy higher than any of the other individual models. Examining the effect of "dark knowledge" in this model compression task, we find that less than 1% of the highest probability labels are needed for accurate model compression.

ME Jul 17, 2014
Sparse Partially Linear Additive Models

Yin Lou, Jacob Bien, Rich Caruana et al.

The generalized partially linear additive model (GPLAM) is a flexible and interpretable approach to building predictive models. It combines features in an additive manner, allowing each to have either a linear or nonlinear effect on the response. However, the choice of which features to treat as linear or nonlinear is typically assumed known. Thus, to make a GPLAM a viable approach in situations in which little is known a priori about the features, one must overcome two primary model selection challenges: deciding which features to include in the model and determining which of these features to treat nonlinearly. We introduce the sparse partially linear additive model (SPLAM), which combines model fitting and both of these model selection challenges into a single convex optimization problem. SPLAM provides a bridge between the lasso and sparse additive models. Through a statistical oracle inequality and thorough simulation, we demonstrate that SPLAM can outperform other methods across a broad spectrum of statistical regimes, including the high-dimensional ($p\gg N$) setting. We develop efficient algorithms that are applied to real data sets with half a million samples and over 45,000 features with excellent predictive performance.

LG Dec 21, 2013
Do Deep Nets Really Need to be Deep?

Lei Jimmy Ba, Rich Caruana

Currently, deep neural networks are the state of the art on problems such as speech recognition and computer vision. In this extended abstract, we show that shallow feed-forward networks can learn the complex functions previously learned by deep nets and achieve accuracies previously only achievable with deep models. Moreover, in some cases the shallow neural nets can learn these deep functions using a total number of parameters similar to the original deep model. We evaluate our method on the TIMIT phoneme recognition task and are able to train shallow fully-connected nets that perform similarly to complex, well-engineered, deep convolutional architectures. Our success in training shallow neural nets to mimic deeper models suggests that there probably exist better algorithms for training shallow feed-forward nets than those currently available.

ML Nov 28, 2013
Using Multiple Samples to Learn Mixture Models

Jason D Lee, Ran Gilad-Bachrach, Rich Caruana

In the mixture models problem it is assumed that there are $K$ distributions $\theta_{1},\ldots,\theta_{K}$ and one gets to observe a sample from a mixture of these distributions with unknown coefficients. The goal is to associate instances with their generating distributions, or to identify the parameters of the hidden distributions. In this work we make the assumption that we have access to several samples drawn from the same $K$ underlying distributions, but with different mixing weights. As with topic modeling, having multiple samples is often a reasonable assumption. Instead of pooling the data into one sample, we prove that it is possible to use the differences between the samples to better recover the underlying structure. We present algorithms that recover the underlying structure under milder assumptions than the current state of the art when either the dimensionality or the separation is high. The methods, when applied to topic modeling, allow generalization to words not present in the training data.