Bin Yu

ML
h-index: 36
35 papers
4,612 citations
Novelty: 55%
AI Score: 55

35 Papers

21.4 · LG · Apr 17, 2023 · Code
Bridging Discrete and Backpropagation: Straight-Through and Beyond

Liyuan Liu, Chengyu Dong, Xiaodong Liu et al.

Backpropagation, the cornerstone of deep learning, is limited to computing gradients for continuous variables. This limitation poses challenges for problems involving discrete latent variables. To address this issue, we propose a novel approach to approximate the gradient of parameters involved in generating discrete latent variables. First, we examine the widely used Straight-Through (ST) heuristic and demonstrate that it works as a first-order approximation of the gradient. Guided by our findings, we propose ReinMax, which achieves second-order accuracy by integrating Heun's method, a second-order numerical method for solving ODEs. ReinMax does not require the Hessian or other second-order derivatives, and therefore incurs negligible computational overhead. Extensive experimental results on various tasks demonstrate the superiority of ReinMax over the state of the art. Implementations are released at https://github.com/microsoft/ReinMax.
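To make the "first-order versus second-order" distinction in this abstract concrete, here is a minimal sketch of Heun's method next to forward Euler on a generic ODE. It is not ReinMax and does not touch discrete latent variables; the test problem dy/dt = y and the helper names `euler`, `heun`, and `growth` are assumptions chosen purely for illustration.

```python
import math

def euler(f, y0, t0, t1, n):
    """Integrate dy/dt = f(t, y) with n forward-Euler steps (first-order)."""
    h = (t1 - t0) / n
    t, y = t0, y0
    for _ in range(n):
        y += h * f(t, y)
        t += h
    return y

def heun(f, y0, t0, t1, n):
    """Integrate dy/dt = f(t, y) with n Heun steps (second-order)."""
    h = (t1 - t0) / n
    t, y = t0, y0
    for _ in range(n):
        k1 = f(t, y)                # slope at the start of the step
        k2 = f(t + h, y + h * k1)   # slope at the Euler-predicted endpoint
        y += 0.5 * h * (k1 + k2)    # advance with the average of both slopes
        t += h
    return y

def growth(t, y):
    """Illustrative test problem: dy/dt = y, whose exact solution is exp(t)."""
    return y

exact = math.e  # y(1) for y(0) = 1
euler_err = [abs(euler(growth, 1.0, 0.0, 1.0, n) - exact) for n in (10, 20)]
heun_err = [abs(heun(growth, 1.0, 0.0, 1.0, n) - exact) for n in (10, 20)]

# Halving the step size should roughly halve a first-order error
# and roughly quarter a second-order error.
euler_ratio = euler_err[0] / euler_err[1]
heun_ratio = heun_err[0] / heun_err[1]
print(f"Euler error ratio: {euler_ratio:.2f}, Heun error ratio: {heun_ratio:.2f}")
```

Heun's method reaches the higher order using only two evaluations of `f` per step and no derivatives of `f`, which is the same flavor of trade-off the abstract describes when it says no Hessian is required.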

3.3 · LG · May 30, 2022 · Code
Group Probability-Weighted Tree Sums for Interpretable Modeling of Heterogeneous Data

Keyan Nasseri, Chandan Singh, James Duncan et al. · berkeley

Machine learning in high-stakes domains, such as healthcare, faces two critical challenges: (1) generalizing to diverse data distributions given limited training data while (2) maintaining interpretability. To address these challenges, we propose an instance-weighted tree-sum method that effectively pools data across diverse groups to output a concise, rule-based model. Given distinct groups of instances in a dataset (e.g., medical patients grouped by age or treatment site), our method first estimates group membership probabilities for each instance. Then, it uses these estimates as instance weights in FIGS (Tan et al. 2022) to grow a set of decision trees whose values sum to the final prediction. We call this new method Group Probability-Weighted Tree Sums (G-FIGS). G-FIGS achieves state-of-the-art prediction performance on important clinical datasets; e.g., holding the level of sensitivity fixed at 92%, G-FIGS increases specificity for identifying cervical spine injury by up to 10% over CART and up to 3% over FIGS alone, with larger gains at higher sensitivity levels. By keeping the total number of rules below 16 in FIGS, the final models remain interpretable, and we find that their rules match medical domain expertise. All code, data, and models are released on GitHub.

1.4 · CV · Jun 17, 2022
Learning Using Privileged Information for Zero-Shot Action Recognition

Zhiyi Gao, Yonghong Hou, Wanqing Li et al.

Zero-Shot Action Recognition (ZSAR) aims to recognize video actions that have never been seen during training. Most existing methods assume a shared semantic space between seen and unseen actions and attempt to directly learn a mapping from a visual space to the semantic space. This approach has been challenged by the semantic gap between the visual space and the semantic space. This paper presents a novel method that uses object semantics as privileged information to narrow the semantic gap and thereby assist learning. In particular, a simple hallucination network is proposed to implicitly extract object semantics during testing without explicitly extracting objects, and a cross-attention module is developed to augment visual features with the object semantics. Experiments on the Olympic Sports, HMDB51 and UCF101 datasets have shown that the proposed method outperforms the state-of-the-art methods by a large margin.

8.9 · ML · Oct 17, 2022
A Mixing Time Lower Bound for a Simplified Version of BART

Omer Ronen, Theo Saarinen, Yan Shuo Tan et al.

Bayesian Additive Regression Trees (BART) is a popular Bayesian non-parametric regression algorithm. The posterior is a distribution over sums of decision trees, and predictions are made by averaging approximate samples from the posterior. The combination of strong predictive performance and the ability to provide uncertainty measures has led BART to be commonly used in the social sciences, biostatistics, and causal inference. BART uses Markov Chain Monte Carlo (MCMC) to obtain approximate posterior samples over a parameterized space of sums of trees, but it has often been observed that the chains are slow to mix. In this paper, we provide the first lower bound on the mixing time for a simplified version of BART in which we reduce the sum to a single tree and use a subset of the possible moves for the MCMC proposal distribution. Our lower bound for the mixing time grows exponentially with the number of data points. Inspired by this new connection between the mixing time and the number of data points, we perform rigorous simulations on BART. We show qualitatively that BART's mixing time increases with the number of data points. The slow mixing time of the simplified BART suggests a large variation between different runs of the simplified BART algorithm and a similar large variation is known for BART in the literature. This large variation could result in a lack of stability in the models, predictions, and posterior intervals obtained from the BART MCMC samples. Our lower bound and simulations suggest increasing the number of chains with the number of data points.

18.9 · CL · Jan 25, 2024 · Code
Towards Consistent Natural-Language Explanations via Explanation-Consistency Finetuning

Yanda Chen, Chandan Singh, Xiaodong Liu et al.

Large language models (LLMs) often generate convincing, fluent explanations. Unlike humans, however, they often generate inconsistent explanations on different inputs. For example, an LLM may generate the explanation "all birds can fly" when answering the question "Can sparrows fly?" but meanwhile answer "no" to the related question "Can penguins fly?". Explanations should be consistent across related examples so that they allow a human to simulate the LLM's decision process on multiple examples. We propose explanation-consistency finetuning (EC-finetuning), a method that adapts LLMs to generate more consistent natural-language explanations on related examples. EC-finetuning involves finetuning LLMs on synthetic data that is carefully constructed to contain consistent explanations. Across a variety of question-answering datasets in various domains, EC-finetuning yields a 10.0% relative explanation consistency improvement on four finetuning datasets, and generalizes to seven out-of-distribution datasets not seen during finetuning (+4.5% relative). Code is available at https://github.com/yandachen/explanation-consistency-finetuning.

17.3 · LG · Feb 2, 2022 · Code
Hierarchical Shrinkage: improving the accuracy and interpretability of tree-based methods

Abhineet Agarwal, Yan Shuo Tan, Omer Ronen et al.

Tree-based models such as decision trees and random forests (RF) are a cornerstone of modern machine-learning practice. To mitigate overfitting, trees are typically regularized by a variety of techniques that modify their structure (e.g. pruning). We introduce Hierarchical Shrinkage (HS), a post-hoc algorithm that does not modify the tree structure, and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors. The amount of shrinkage is controlled by a single regularization parameter and the number of data points in each ancestor. Since HS is a post-hoc method, it is extremely fast, compatible with any tree-growing algorithm, and can be used synergistically with other regularization techniques. Extensive experiments over a wide variety of real-world datasets show that HS substantially increases the predictive performance of decision trees, even when used in conjunction with other regularization techniques. Moreover, we find that applying HS to each tree in an RF often improves accuracy, as well as its interpretability by simplifying and stabilizing its decision boundaries and SHAP values. We further explain the success of HS in improving prediction performance by showing its equivalence to ridge regression on a (supervised) basis constructed of decision stumps associated with the internal nodes of a tree. All code and models are released in a full-fledged package available on GitHub (github.com/csinva/imodels).
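The shrinkage idea described in this abstract can be sketched for a single root-to-leaf path. The specific form used below, in which each parent-to-child change in the node mean is divided by 1 + λ / N(parent), is one reading of "controlled by a single regularization parameter and the number of data points in each ancestor" and should be checked against the paper; the function name `hierarchical_shrinkage` and the toy path are hypothetical and are not the imodels API.

```python
def hierarchical_shrinkage(path, lam):
    """Shrunk prediction for one leaf.

    path: list of (node_mean, node_count) pairs from the root down to the leaf.
    lam:  regularization strength (lam = 0 recovers the ordinary leaf mean).
    """
    prediction = path[0][0]  # start from the root's sample mean
    for (parent_mean, parent_n), (child_mean, _) in zip(path, path[1:]):
        # Each step down the tree adds the change in node mean,
        # damped more strongly when the parent holds few samples.
        prediction += (child_mean - parent_mean) / (1.0 + lam / parent_n)
    return prediction

# Toy path (assumed numbers): root -> internal node -> leaf.
path = [(10.0, 100), (14.0, 40), (20.0, 5)]
unshrunk = hierarchical_shrinkage(path, lam=0.0)   # plain leaf mean
shrunk = hierarchical_shrinkage(path, lam=10.0)    # pulled back toward ancestors
print(unshrunk, round(shrunk, 4))
```

Because the tree structure is never touched, a routine like this can be applied after any tree-growing algorithm, which matches the post-hoc framing in the abstract.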

24.6 · SE · Jun 10, 2025 · Code
UTBoost: Rigorous Evaluation of Coding Agents on SWE-Bench

Boxi Yu, Yuxuan Zhu, Pinjia He et al.

The advent of Large Language Models (LLMs) has spurred the development of coding agents for real-world code generation. As a widely used benchmark for evaluating the code generation capabilities of these agents, SWE-Bench uses real-world problems based on GitHub issues and their corresponding pull requests. However, the manually written test cases included in these pull requests are often insufficient, allowing generated patches to pass the tests without resolving the underlying issue. To address this challenge, we introduce UTGenerator, an LLM-driven test case generator that automatically analyzes codebases and dependencies to generate test cases for real-world Python projects. Building on UTGenerator, we propose UTBoost, a comprehensive framework for test case augmentation. In our evaluation, we identified 36 task instances with insufficient test cases and uncovered 345 erroneous patches incorrectly labeled as passed in the original SWE-Bench. These corrections, impacting 40.9% of SWE-Bench Lite and 24.4% of SWE-Bench Verified leaderboard entries, yield 18 and 11 ranking changes, respectively.

14.4 · LG · Sep 10, 2025
EvolKV: Evolutionary KV Cache Compression for LLM Inference

Bohan Yu, Yekun Chai

Existing key-value (KV) cache compression methods typically rely on heuristics, such as uniform cache allocation across layers or static eviction policies. However, they ignore the critical interplay between layer-specific feature patterns and task performance, which can lead to degraded generalization. In this paper, we propose EvolKV, an adaptive framework for layer-wise, task-driven KV cache compression that jointly optimizes memory efficiency and task performance. By reformulating cache allocation as a multi-objective optimization problem, EvolKV leverages evolutionary search to dynamically configure layer budgets while directly maximizing downstream performance. Extensive experiments on 11 tasks demonstrate that our approach outperforms all baseline methods across a wide range of KV cache budgets on long-context tasks and surpasses heuristic baselines by up to 7 percentage points on GSM8K. Notably, EvolKV achieves superior performance over the full KV cache setting on code completion while utilizing only 1.5% of the original budget, suggesting the untapped potential in learned compression strategies for KV cache budget allocation.

15.5 · CV · Dec 14, 2025
StreamingAssistant: Efficient Visual Token Pruning for Accelerating Online Video Understanding

Xinqi Jin, Hanxun Yu, Bohan Yu et al.

Online video understanding is essential for applications like public surveillance and AI glasses. However, applying Multimodal Large Language Models (MLLMs) to this domain is challenging due to the large number of video frames, resulting in high GPU memory usage and computational latency. To address these challenges, we propose token pruning as a means to reduce context length while retaining critical information. Specifically, we introduce a novel redundancy metric, Maximum Similarity to Spatially Adjacent Video Tokens (MSSAVT), which accounts for both token similarity and spatial position. To mitigate the bidirectional dependency between pruning and redundancy, we further design a masked pruning strategy that ensures only mutually unadjacent tokens are pruned. We also integrate an existing temporal redundancy-based pruning method to eliminate temporal redundancy of the video modality. Experimental results on multiple online and offline video understanding benchmarks demonstrate that our method significantly improves accuracy (by up to 4%) while incurring negligible pruning latency (less than 1 ms). Our full implementation will be made publicly available.

7.1 · LG · Jun 10, 2025
Local MDI+: Local Feature Importances for Tree-Based Models

Zhongyuan Liang, Zachary T. Rewolinski, Abhineet Agarwal et al.

Tree-based ensembles such as random forests remain the go-to choice for tabular data over deep learning models due to their prediction performance and computational efficiency. These advantages have led to their widespread deployment in high-stakes domains, where interpretability is essential for ensuring trustworthy predictions. This has motivated the development of popular local (i.e. sample-specific) feature importance (LFI) methods such as LIME and TreeSHAP. However, these approaches rely on approximations that ignore the model's internal structure and instead depend on potentially unstable perturbations. These issues are addressed in the global setting by MDI+, a feature importance method which exploits an equivalence between decision trees and linear models on a transformed node basis. However, the global MDI+ scores are not able to explain predictions when faced with heterogeneous individual characteristics. To address this gap, we propose Local MDI+ (LMDI+), a novel extension of the MDI+ framework to the sample-specific setting. LMDI+ outperforms existing baselines LIME and TreeSHAP in identifying instance-specific signal features, averaging a 10% improvement in downstream task performance across twelve real-world benchmark datasets. It further demonstrates greater stability by consistently producing similar instance-level feature importance rankings across multiple random forest fits. Finally, LMDI+ enables local interpretability use cases, including the identification of closer counterfactuals and the discovery of homogeneous subgroups.

5.5 · ML · Jun 28, 2024 · Code
On the Computational Efficiency of Bayesian Additive Regression Trees: An Asymptotic Analysis

Yan Shuo Tan, Omer Ronen, Theo Saarinen et al.

Bayesian Additive Regression Trees (BART) is a popular Bayesian non-parametric regression model that is commonly used in causal inference and beyond. Its strong predictive performance is supported by well-developed estimation theory, comprising guarantees that its posterior distribution concentrates around the true regression function at optimal rates under various data generative settings and for appropriate prior choices. However, the computational properties of the widely used BART sampler proposed by Chipman et al. (2010) are not yet well understood. In this paper, we perform an asymptotic analysis of a slightly modified version of the default BART sampler when fitted to data-generating processes with discrete covariates. We show that the sampler's time to convergence, evaluated in terms of the hitting time of a high posterior density set, increases with the number of training samples, due to the multi-modal nature of the target posterior. On the other hand, we show that this trend can be dampened by simple changes, such as increasing the number of trees in the ensemble or raising the temperature of the sampler. These results provide a nuanced picture of the computational efficiency of the BART sampler in the presence of large amounts of training data while suggesting strategies to improve the sampler. We complement our theoretical analysis with a simulation study focusing on the default BART sampler. We observe that the increasing trend of convergence time against the number of training samples holds for the default BART sampler and is robust to changes in sampler initialization, number of burn-in iterations, feature selection prior, and discretization strategy. On the other hand, increasing the number of trees or raising the temperature sharply dampens this trend, as indicated by our theory.

0.9 · CL · May 27, 2023 · Code
Diagnosing Transformers: Illuminating Feature Spaces for Clinical Decision-Making

Aliyah R. Hsu, Yeshwanth Cherapanamjeri, Briton Park et al.

Pre-trained transformers are often fine-tuned to aid clinical decision-making using limited clinical notes. Model interpretability is crucial, especially in high-stakes domains like medicine, to establish trust and ensure safety, which requires human engagement. We introduce SUFO, a systematic framework that enhances interpretability of fine-tuned transformer feature spaces. SUFO utilizes a range of analytic and visualization techniques, including Supervised probing, Unsupervised similarity analysis, Feature dynamics, and Outlier analysis to address key questions about model trust and interpretability. We conduct a case study investigating the impact of pre-training data where we focus on real-world pathology classification tasks, and validate our findings on MedNLI. We evaluate five 110M-sized pre-trained transformer models, categorized into general-domain (BERT, TNLR), mixed-domain (BioBERT, Clinical BioBERT), and domain-specific (PubMedBERT) groups. Our SUFO analyses reveal that: (1) while PubMedBERT, the domain-specific model, contains valuable information for fine-tuning, it can overfit to minority classes when class imbalances exist. In contrast, mixed-domain models exhibit greater resistance to overfitting, suggesting potential improvements in domain-specific model robustness; (2) in-domain pre-training accelerates feature disambiguation during fine-tuning; and (3) feature spaces undergo significant sparsification during this process, enabling clinicians to identify common outlier modes among fine-tuned models as demonstrated in this paper. These findings showcase the utility of SUFO in enhancing trust and safety when using transformers in medicine, and we believe SUFO can aid practitioners in evaluating fine-tuned language models for other applications in medicine and in more critical domains.

28.1 · AI · May 17, 2023 · Code
Explaining black box text modules in natural language with language models

Chandan Singh, Aliyah R. Hsu, Richard Antonello et al.

Large language models (LLMs) have demonstrated remarkable prediction performance for a growing array of tasks. However, their rapid proliferation and increasing opaqueness have created a growing need for interpretability. Here, we ask whether we can automatically obtain natural language explanations for black box text modules. A "text module" is any function that maps text to a scalar continuous value, such as a submodule within an LLM or a fitted model of a brain region. "Black box" indicates that we only have access to the module's inputs/outputs. We introduce Summarize and Score (SASC), a method that takes in a text module and returns a natural language explanation of the module's selectivity along with a score for how reliable the explanation is. We study SASC in 3 contexts. First, we evaluate SASC on synthetic modules and find that it often recovers ground truth explanations. Second, we use SASC to explain modules found within a pre-trained BERT model, enabling inspection of the model's internals. Finally, we show that SASC can generate explanations for the response of individual fMRI voxels to language stimuli, with potential applications to fine-grained brain mapping. All code for using SASC and reproducing results is made available on GitHub.

17.4 · ML · Nov 13, 2021
The Three Stages of Learning Dynamics in High-Dimensional Kernel Methods

Nikhil Ghosh, Song Mei, Bin Yu

To understand how deep learning works, it is crucial to understand the training dynamics of neural networks. Several interesting hypotheses about these dynamics have been made based on empirically observed phenomena, but there exists a limited theoretical understanding of when and why such phenomena occur. In this paper, we consider the training dynamics of gradient flow on kernel least-squares objectives, which is a limiting dynamics of SGD-trained neural networks. Using precise high-dimensional asymptotics, we characterize the dynamics of the fitted model in two "worlds": in the Oracle World the model is trained on the population distribution and in the Empirical World the model is trained on a sampled dataset. We show that under mild conditions on the kernel and $L^2$ target regression function the training dynamics undergo three stages characterized by the behaviors of the models in the two worlds. Our theoretical results also mathematically formalize some interesting deep learning phenomena. Specifically, in our setting we show that SGD progressively learns more complex functions and that there is a "deep bootstrap" phenomenon: during the second stage, the test errors in both worlds remain close despite the empirical training error being much smaller. Finally, we give a concrete example comparing the dynamics of two different kernels which shows that faster training is not necessary for better generalization.

3.6 · ML · Aug 16, 2021 · Code
Interpreting and improving deep-learning models with reality checks

Chandan Singh, Wooseok Ha, Bin Yu

Recent deep-learning models have achieved impressive predictive performance by learning complex functions of many variables, often at the cost of interpretability. This chapter covers recent work aiming to interpret models by attributing importance to features and feature groups for a single prediction. Importantly, the proposed attributions assign importance to interactions between features, in addition to features in isolation. These attributions are shown to yield insights across real-world domains, including bio-imaging, cosmology imaging, and natural-language processing. We then show how these attributions can be used to directly improve the generalization of a neural network or to distill it into a simple model. Throughout the chapter, we emphasize the use of reality checks to scrutinize the proposed interpretation techniques.

11.7 · ME · Aug 23, 2020 · Code
Stable discovery of interpretable subgroups via calibration in causal studies

Raaz Dwivedi, Yan Shuo Tan, Briton Park et al.

Building on Yu and Kumbier's PCS framework and for randomized experiments, we introduce a novel methodology for Stable Discovery of Interpretable Subgroups via Calibration (StaDISC), with large heterogeneous treatment effects. StaDISC was developed during our re-analysis of the 1999-2000 VIGOR study, an 8076-patient randomized controlled trial (RCT) that compared the risk of adverse events from a then newly approved drug, Rofecoxib (Vioxx), to that from an older drug, Naproxen. Vioxx was found to, on average and in comparison to Naproxen, reduce the risk of gastrointestinal (GI) events but increase the risk of thrombotic cardiovascular (CVT) events. Applying StaDISC, we fit 18 popular conditional average treatment effect (CATE) estimators for both outcomes and use calibration to demonstrate their poor global performance. However, they are locally well-calibrated and stable, enabling the identification of patient groups with larger than (estimated) average treatment effects. In fact, StaDISC discovers three clinically interpretable subgroups each for the GI outcome (totaling 29.4% of the study size) and the CVT outcome (totaling 11.0%). Complementary analyses of the found subgroups using the 2001-2004 APPROVe study, a separate independently conducted RCT with 2587 patients, provide further supporting evidence for the promise of StaDISC.

10.1 · LG · Jun 17, 2020 · Code
Revisiting minimum description length complexity in overparameterized models

Raaz Dwivedi, Chandan Singh, Bin Yu et al.

Complexity is a fundamental concept underlying statistical learning theory that aims to inform generalization performance. Parameter count, while successful in low-dimensional settings, is not well-justified for overparameterized settings when the number of parameters is more than the number of training samples. We revisit complexity measures based on Rissanen's principle of minimum description length (MDL) and define a novel MDL-based complexity (MDL-COMP) that remains valid for overparameterized models. MDL-COMP is defined via an optimality criterion over the encodings induced by a good Ridge estimator class. We provide an extensive theoretical characterization of MDL-COMP for linear models and kernel methods and show that it is not just a function of parameter count, but rather a function of the singular values of the design or the kernel matrix and the signal-to-noise ratio. For a linear model with $n$ observations, $d$ parameters, and i.i.d. Gaussian predictors, MDL-COMP scales linearly with $d$ when $d<n$, but the scaling is exponentially smaller -- $\log d$ for $d>n$. For kernel methods, we show that MDL-COMP informs minimax in-sample error, and can decrease as the dimensionality of the input increases. We also prove that MDL-COMP upper bounds the in-sample mean squared error (MSE). Via an array of simulations and real-data experiments, we show that a data-driven Prac-MDL-COMP informs hyper-parameter tuning for optimizing test MSE with ridge regression in limited data settings, sometimes improving upon cross-validation and (always) saving computational costs. Finally, our findings also suggest that the recently observed double descent phenomena in overparameterized models might be a consequence of the choice of non-ideal estimators.

11.1 · LG · May 22, 2020
Instability, Computational Efficiency and Statistical Accuracy

Nhat Ho, Koulik Khamaru, Raaz Dwivedi et al.

Many statistical estimators are defined as the fixed point of a data-dependent operator, with estimators based on minimizing a cost function being an important special case. The limiting performance of such estimators depends on the properties of the population-level operator in the idealized limit of infinitely many samples. We develop a general framework that yields bounds on statistical accuracy based on the interplay between the deterministic convergence rate of the algorithm at the population level, and its degree of (in)stability when applied to an empirical object based on $n$ samples. Using this framework, we analyze both stable forms of gradient descent and some higher-order and unstable algorithms, including Newton's method and its cubic-regularized variant, as well as the EM algorithm. We provide applications of our general results to several concrete classes of models, including Gaussian mixture estimation, non-linear regression models, and informative non-response models. We exhibit cases in which an unstable algorithm can achieve the same statistical accuracy as a stable algorithm in exponentially fewer steps -- namely, with the number of iterations being reduced from polynomial to logarithmic in sample size $n$.

17.1 · ML · Jun 26, 2019 · Code
A Debiased MDI Feature Importance Measure for Random Forests

Xiao Li, Yu Wang, Sumanta Basu et al.

Tree ensembles such as Random Forests have achieved impressive empirical success across a wide variety of applications. To understand how these models make predictions, people routinely turn to feature importance measures calculated from tree ensembles. It has long been known that Mean Decrease Impurity (MDI), one of the most widely used measures of feature importance, incorrectly assigns high importance to noisy features, leading to systematic bias in feature selection. In this paper, we address the feature selection bias of MDI from both theoretical and methodological perspectives. Based on the original definition of MDI by Breiman et al. for a single tree, we derive a tight non-asymptotic bound on the expected bias of MDI importance of noisy features, showing that deep trees have higher (expected) feature selection bias than shallow ones. However, it is not clear how to reduce the bias of MDI using its existing analytical expression. We derive a new analytical expression for MDI, and based on this new expression, we are able to propose a debiased MDI feature importance measure using out-of-bag samples, called MDI-oob. For both the simulated data and a genomic ChIP dataset, MDI-oob achieves state-of-the-art performance in feature selection from Random Forests for both deep and shallow trees.

12.6 · ML · May 18, 2019 · Code
Disentangled Attribution Curves for Interpreting Random Forests and Boosted Trees

Summer Devlin, Chandan Singh, W. James Murdoch et al.

Tree ensembles, such as random forests and AdaBoost, are ubiquitous machine learning models known for achieving strong predictive performance across a wide variety of domains. However, this strong performance comes at the cost of interpretability (i.e. users are unable to understand the relationships a trained random forest has learned and why it is making its predictions). In particular, it is challenging to understand how the contribution of a particular feature, or group of features, varies as their value changes. To address this, we introduce Disentangled Attribution Curves (DAC), a method to provide interpretations of tree ensemble methods in the form of (multivariate) feature importance curves. For a given variable, or group of variables, DAC plots the importance of the variable(s) as their value changes. We validate DAC on real data by showing that the curves can be used to increase the accuracy of logistic regression while maintaining interpretability, by including DAC as an additional feature. In simulation studies, DAC is shown to outperform competing methods in the recovery of conditional expectations. Finally, through a case study on the bike-sharing dataset, we demonstrate the use of DAC to uncover novel insights into a dataset.

16.1 · ST · Feb 1, 2019
Sharp Analysis of Expectation-Maximization for Weakly Identifiable Models

Raaz Dwivedi, Nhat Ho, Koulik Khamaru et al.

We study a class of weakly identifiable location-scale mixture models for which the maximum likelihood estimates based on $n$ i.i.d. samples are known to have lower accuracy than the classical $n^{-1/2}$ error. We investigate whether the Expectation-Maximization (EM) algorithm also converges slowly for these models. We provide a rigorous characterization of EM for fitting a weakly identifiable Gaussian mixture in a univariate setting where we prove that the EM algorithm converges in order $n^{3/4}$ steps and returns estimates that are at a Euclidean distance of order $n^{-1/8}$ and $n^{-1/4}$ from the true location and scale parameter respectively. Establishing the slow rates in the univariate setting requires a novel localization argument with two stages, with each stage involving an epoch-based argument applied to a different surrogate EM operator at the population level. We demonstrate several multivariate ($d \geq 2$) examples that exhibit the same slow rates as the univariate case. We also prove slow statistical rates in higher dimensions in a special case, when the fitted covariance is constrained to be a multiple of the identity.

40.4 · ML · Jan 14, 2019 · Code
Interpretable machine learning: definitions, methods, and applications

W. James Murdoch, Chandan Singh, Karl Kumbier et al.

Machine-learning models have demonstrated great success in learning complex patterns that enable them to make predictions about unobserved data. In addition to using models for prediction, the ability to interpret what a model has learned is receiving an increasing amount of attention. However, this increased focus has led to considerable confusion about the notion of interpretability. In particular, it is unclear how the wide array of proposed interpretation methods are related, and what common concepts can be used to evaluate them. We aim to address these concerns by defining interpretability in the context of machine learning and introducing the Predictive, Descriptive, Relevant (PDR) framework for discussing interpretations. The PDR framework provides three overarching desiderata for evaluation: predictive accuracy, descriptive accuracy and relevancy, with relevancy judged relative to a human audience. Moreover, to help manage the deluge of interpretation methods, we introduce a categorization of existing techniques into model-based and post-hoc categories, with sub-groups including sparsity, modularity and simulatability. To demonstrate how practitioners can use the PDR framework to evaluate and understand interpretations, we provide numerous real-world examples. These examples highlight the often under-appreciated role played by human audiences in discussions of interpretability. Finally, based on our framework, we discuss limitations of existing methods and directions for future work. We hope that this work will provide a common vocabulary that will make it easier for both practitioners and researchers to discuss and choose from the full range of interpretation methods.

20.2STOct 1, 2018
Singularity, Misspecification, and the Convergence Rate of EM

Raaz Dwivedi, Nhat Ho, Koulik Khamaru et al.

A line of recent work has analyzed the behavior of the Expectation-Maximization (EM) algorithm in the well-specified setting, in which the population likelihood is locally strongly concave around its maximizing argument. Examples include suitably separated Gaussian mixture models and mixtures of linear regressions. We consider over-specified settings in which the number of fitted components is larger than the number of components in the true distribution. Such misspecified settings can lead to singularity in the Fisher information matrix, and moreover, the maximum likelihood estimator based on $n$ i.i.d. samples in $d$ dimensions can have a non-standard $\mathcal{O}((d/n)^{\frac{1}{4}})$ rate of convergence. Focusing on the simple setting of two-component mixtures fit to a $d$-dimensional Gaussian distribution, we study the behavior of the EM algorithm both when the mixture weights are different (unbalanced case), and are equal (balanced case). Our analysis reveals a sharp distinction between these two cases: in the former, the EM algorithm converges geometrically to a point at Euclidean distance of $\mathcal{O}((d/n)^{\frac{1}{2}})$ from the true parameter, whereas in the latter case, the convergence rate is exponentially slower, and the fixed point has a much lower $\mathcal{O}((d/n)^{\frac{1}{4}})$ accuracy. Analysis of this singular case requires the introduction of some novel techniques: in particular, we make use of a careful form of localization in the associated empirical process, and develop a recursive argument to progressively sharpen the statistical rate.

25.9LGJun 14, 2018Code
Hierarchical interpretations for neural network predictions

Chandan Singh, W. James Murdoch, Bin Yu

Deep neural networks (DNNs) have achieved impressive predictive performance due to their ability to learn complex, non-linear relationships between variables. However, the inability to effectively visualize these relationships has led to DNNs being characterized as black boxes and consequently limited their applications. To ameliorate this problem, we introduce the use of hierarchical interpretations to explain DNN predictions through our proposed method, agglomerative contextual decomposition (ACD). Given a prediction from a trained DNN, ACD produces a hierarchical clustering of the input features, along with the contribution of each cluster to the final prediction. This hierarchy is optimized to identify clusters of features that the DNN learned are predictive. Using examples from Stanford Sentiment Treebank and ImageNet, we show that ACD is effective at diagnosing incorrect predictions and identifying dataset bias. Through human experiments, we demonstrate that ACD enables users both to identify the more accurate of two DNNs and to better trust a DNN's outputs. We also find that ACD's hierarchy is largely robust to adversarial perturbations, implying that it captures fundamental aspects of the input and ignores spurious noise.

25.8MLApr 4, 2018
Stability and Convergence Trade-off of Iterative Optimization Algorithms

Yuansi Chen, Chi Jin, Bin Yu

The overall performance or expected excess risk of an iterative machine learning algorithm can be decomposed into training error and generalization error. While the former is controlled by its convergence analysis, the latter can be tightly handled by algorithmic stability. The machine learning community has a rich history investigating convergence and stability separately. However, the question about the trade-off between these two quantities remains open. In this paper, we show that for any iterative algorithm at any iteration, the overall performance is lower bounded by the minimax statistical error over an appropriately chosen loss function class. This implies an important trade-off between convergence and stability of the algorithm -- a faster converging algorithm has to be less stable, and vice versa. As a direct consequence of this fundamental trade-off, new convergence lower bounds can be derived for classes of algorithms constrained with different stability bounds. In particular, when the loss function is convex (or strongly convex) and smooth, we discuss the stability upper bounds of gradient descent (GD) and stochastic gradient descent and their variants with decreasing step sizes. For Nesterov's accelerated gradient descent (NAG) and heavy ball method (HB), we provide stability upper bounds for the quadratic loss function. Applying existing stability upper bounds for the gradient methods in our trade-off framework, we obtain lower bounds matching the well-established convergence upper bounds up to constants for these algorithms and conjecture similar lower bounds for NAG and HB. Finally, we numerically demonstrate the tightness of our stability bounds in terms of exponents in the rate and also illustrate via a simulated logistic regression problem that our stability bounds reflect the generalization errors better than the simple uniform convergence bounds for GD and NAG.

52.5MLDec 8, 2017
Artificial Intelligence and Statistics

Bin Yu, Karl Kumbier

Artificial intelligence (AI) is intrinsically data-driven. It calls for the application of statistical concepts through human-machine collaboration during generation of data, development of algorithms, and evaluation of results. This paper discusses how such human-machine collaboration can be approached through the statistical concepts of population, question of interest, representativeness of training data, and scrutiny of results (PQRS). The PQRS workflow provides a conceptual framework for integrating statistical ideas with human input into AI products and research. These ideas include experimental design principles of randomization and local control as well as the principle of stability to gain reproducibility and interpretability of algorithms and data results. We discuss the use of these principles in the contexts of self-driving cars, automated medical diagnoses, and examples from the authors' collaborative research.

11.4MLNov 7, 2017
Interpreting Convolutional Neural Networks Through Compression

Reza Abbasi-Asl, Bin Yu

Convolutional neural networks (CNNs) achieve state-of-the-art performance in a wide variety of tasks in computer vision. However, interpreting CNNs remains a challenge. This is mainly due to the large number of parameters in these networks. Here, we investigate the role of compression and particularly pruning filters in the interpretation of CNNs. We exploit our recently-proposed greedy structural compression scheme that prunes filters in a trained CNN. In our compression, the filter importance index is defined as the classification accuracy reduction (CAR) of the network after pruning that filter. The filters are then iteratively pruned based on the CAR index. We demonstrate the interpretability of CAR-compressed CNNs by showing that our algorithm prunes filters with visually redundant pattern selectivity. Specifically, we show the importance of shape-selective filters for object recognition, as opposed to color-selective filters. Out of the top 20 CAR-pruned filters in AlexNet, 17 of them in the first layer and 14 of them in the second layer are color-selective filters. Finally, we introduce a variant of our CAR importance index that quantifies the importance of each image class to each CNN filter. We show that the most and the least important class labels present a meaningful interpretation of each filter that is consistent with the visualized pattern selectivity of that filter.

14.0CVMay 20, 2017Code
Structural Compression of Convolutional Neural Networks

Reza Abbasi-Asl, Bin Yu

Deep convolutional neural networks (CNNs) have been successful in many tasks in machine vision. However, the millions of weights in the form of thousands of convolutional filters in CNNs make them difficult for human interpretation or understanding in science. In this article, we introduce CAR, a greedy structural compression scheme to obtain smaller and more interpretable CNNs, while achieving close to original accuracy. The compression is based on pruning filters with the least contribution to the classification accuracy. We demonstrate the interpretability of CAR-compressed CNNs by showing that our algorithm prunes filters with visually redundant functionalities such as color filters. These compressed networks are easier to interpret because they retain the filter diversity of uncompressed networks with an order of magnitude fewer filters. Finally, a variant of CAR is introduced to quantify the importance of each image category to each CNN filter. Specifically, the most and the least important class labels are shown to be meaningful interpretations of each filter.
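
The greedy pruning loop described in this abstract can be illustrated with a toy sketch. Everything below is an assumption made for illustration: the function `car_prune`, the filter names, and the additive `toy_accuracy` stand-in are hypothetical, and a real use would evaluate a trained CNN on held-out data instead. The abstract does not specify details such as retraining between pruning steps, so none is modeled here.

```python
def car_prune(filters, accuracy, num_to_prune):
    """Greedily prune the filters whose removal reduces accuracy the least.

    `accuracy` is a caller-supplied function mapping a list of kept filters
    to a classification accuracy (hypothetical stand-in for evaluating a CNN).
    """
    kept = list(filters)
    pruned = []
    for _ in range(num_to_prune):
        base = accuracy(kept)
        # CAR index of f: accuracy drop when f alone is removed from the current network.
        car = {f: base - accuracy([g for g in kept if g != f]) for f in kept}
        victim = min(kept, key=lambda f: car[f])
        kept.remove(victim)
        pruned.append(victim)
    return kept, pruned


# Toy stand-in: each filter contributes a fixed amount to accuracy (an assumption).
contribution = {"edge": 0.40, "blob": 0.30, "red": 0.05, "green": 0.02}

def toy_accuracy(kept):
    return sum(contribution[f] for f in kept)

kept, pruned = car_prune(list(contribution), toy_accuracy, num_to_prune=2)
```

In this toy setup the two color filters have the smallest accuracy reduction and are pruned first, which mirrors the qualitative finding reported in the abstract.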

6.3MLOct 23, 2016
Formulas for Counting the Sizes of Markov Equivalence Classes of Directed Acyclic Graphs

Yangbo He, Bin Yu

The sizes of Markov equivalence classes of directed acyclic graphs play important roles in measuring the uncertainty and complexity in causal learning. A Markov equivalence class can be represented by an essential graph and its undirected subgraphs determine the size of the class. In this paper, we develop a method to derive the formulas for counting the sizes of Markov equivalence classes. We first introduce a new concept of core graph. The size of a Markov equivalence class of interest is a polynomial of the number of vertices given its core graph. Then, we discuss the recursive and explicit formula of the polynomial, and provide an algorithm to derive the size formula via symbolic computation for any given core graph. The proposed size formula derivation sheds light on the relationships between the size of a Markov equivalence class and its representation graph, and makes size counting efficient, even when the essential graphs contain non-sparse undirected subgraphs.

5.1MESep 17, 2015
Optimal Subsampling Approaches for Large Sample Linear Regression

Rong Zhu, Ping Ma, Michael W. Mahoney et al.

A significant hurdle for analyzing large sample data is the lack of effective statistical computing and inference methods. An emerging powerful approach for analyzing large sample data is subsampling, by which one takes a random subsample from the original full sample and uses it as a surrogate for subsequent computation and estimation. In this paper, we study subsampling methods under two scenarios: approximating the full sample ordinary least-squares (OLS) estimator and estimating the coefficients in linear regression. We present two algorithms, a weighted estimation algorithm and an unweighted estimation algorithm, and analyze asymptotic behaviors of their resulting subsample estimators under general conditions. For the weighted estimation algorithm, we propose a criterion for selecting the optimal sampling probability by making use of the asymptotic results. On the basis of the criterion, we provide two novel subsampling methods, the optimal subsampling and the predictor-length subsampling methods. The predictor-length subsampling method is based on the L2 norm of predictors rather than leverage scores. Its computational cost is scalable. For the unweighted estimation algorithm, we show that its resulting subsample estimator is not consistent to the full sample OLS estimator. However, it has better performance than the weighted estimation algorithm for estimating the coefficients. Simulation studies and a real data example are used to demonstrate the effectiveness of our proposed subsampling methods.
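
As a rough illustration of the weighted estimation idea, the sketch below draws a subsample with probabilities proportional to the length of each predictor and reweights by the inverse sampling probability. It is a minimal sketch under strong assumptions that are not in the abstract: a scalar predictor with no intercept (so the L2 norm is just the absolute value), sampling with replacement, and the hypothetical function name `predictor_length_subsample_slope`.

```python
import random

def predictor_length_subsample_slope(xs, ys, r, rng):
    """Weighted subsample estimate of beta in y = beta * x + noise (scalar x, no intercept)."""
    lengths = [abs(x) for x in xs]            # L2 norm of a scalar predictor
    total = sum(lengths)
    probs = [length / total for length in lengths]
    # Draw r indices with replacement, with probability proportional to predictor length.
    idx = rng.choices(range(len(xs)), weights=probs, k=r)
    # Inverse-probability weighted least squares on the subsample.
    num = sum(xs[i] * ys[i] / probs[i] for i in idx)
    den = sum(xs[i] * xs[i] / probs[i] for i in idx)
    return num / den


rng = random.Random(1)
xs = [rng.gauss(0.0, 1.0) for _ in range(5000)]
ys = [3.0 * x + rng.gauss(0.0, 1.0) for x in xs]
slope = predictor_length_subsample_slope(xs, ys, r=500, rng=rng)
```

With 500 of 5000 points, the estimate in this simulated example lands close to the true slope of 3, though the exact value depends on the seed.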

17.7MLNov 15, 2014
Error Rate Bounds and Iterative Weighted Majority Voting for Crowdsourcing

Hongwei Li, Bin Yu

Crowdsourcing has become an effective and popular tool for human-powered computation to label large datasets. Since the workers can be unreliable, it is common in crowdsourcing to assign multiple workers to one task, and to aggregate the labels in order to obtain results of high quality. In this paper, we provide finite-sample exponential bounds on the error rate (in probability and in expectation) of general aggregation rules under the Dawid-Skene crowdsourcing model. The bounds are derived for multi-class labeling, and can be used to analyze many aggregation methods, including majority voting, weighted majority voting and the oracle Maximum A Posteriori (MAP) rule. We show that the oracle MAP rule approximately optimizes our upper bound on the mean error rate of weighted majority voting in certain settings. We propose an iterative weighted majority voting (IWMV) method that optimizes the error rate bound and approximates the oracle MAP rule. Its one-step version has a provable theoretical guarantee on the error rate. The IWMV method is intuitive and computationally simple. Experimental results on simulated and real data show that IWMV performs at least on par with the state-of-the-art methods, and it has a much lower computational cost (around one hundred times faster) than the state-of-the-art methods.
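
The alternation between weighted voting and re-estimating worker reliability can be sketched as follows. The abstract does not give the update rule, so the weight formula used here (number of classes times estimated accuracy, minus one) is an assumption, as are the function name `iwmv` and the input layout.

```python
from collections import defaultdict

def iwmv(labels, num_classes, max_iter=50):
    """Iterative weighted majority voting (illustrative sketch).

    labels: dict mapping (worker, item) -> class index in range(num_classes).
    Returns (predicted label per item, weight per worker).
    """
    workers = sorted({w for w, _ in labels})
    items = sorted({i for _, i in labels})
    weights = {w: 1.0 for w in workers}   # start from plain majority voting
    pred = {}
    for _ in range(max_iter):
        # Aggregation step: weighted vote per item.
        scores = {i: [0.0] * num_classes for i in items}
        for (w, i), k in labels.items():
            scores[i][k] += weights[w]
        new_pred = {i: max(range(num_classes), key=lambda k: scores[i][k]) for i in items}
        if new_pred == pred:
            break
        pred = new_pred
        # Reliability step: estimate each worker's accuracy against the current labels.
        hits, counts = defaultdict(int), defaultdict(int)
        for (w, i), k in labels.items():
            counts[w] += 1
            hits[w] += int(k == pred[i])
        # Assumed weight rule: a worker at chance level (1/num_classes) gets weight 0.
        weights = {w: num_classes * hits[w] / counts[w] - 1.0 for w in workers}
    return pred, weights


truth = [0, 1, 0, 1, 1]
labels = {}
for item, y in enumerate(truth):
    labels[("a", item)] = y          # reliable worker
    labels[("b", item)] = y          # reliable worker
    labels[("c", item)] = 1 - y      # worker who is always wrong
pred, weights = iwmv(labels, num_classes=2)
```

In this small example the always-wrong worker ends up with a negative weight, so its votes count against the labels it gives.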

35.2STAug 9, 2014
Statistical guarantees for the EM algorithm: From population to sample-based analysis

Sivaraman Balakrishnan, Martin J. Wainwright, Bin Yu

We develop a general framework for proving rigorous guarantees on the performance of the EM algorithm and a variant known as gradient EM. Our analysis is divided into two parts: a treatment of these algorithms at the population level (in the limit of infinite data), followed by results that apply to updates based on a finite set of samples. First, we characterize the domain of attraction of any global maximizer of the population likelihood. This characterization is based on a novel view of the EM updates as a perturbed form of likelihood ascent, or in parallel, of the gradient EM updates as a perturbed form of standard gradient ascent. Leveraging this characterization, we then provide non-asymptotic guarantees on the EM and gradient EM algorithms when applied to a finite set of samples. We develop consequences of our general theory for three canonical examples of incomplete-data problems: mixture of Gaussians, mixture of regressions, and linear regression with covariates missing completely at random. In each case, our theory guarantees that with a suitable initialization, a relatively small number of EM (or gradient EM) steps will yield (with high probability) an estimate that is within statistical error of the MLE. We provide simulations to confirm this theoretically predicted behavior.
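
For the mixture-of-Gaussians example, a sample-based EM iteration can be written in a few lines. The sketch below assumes the simplest case, which is an illustrative choice rather than something stated in the abstract: a one-dimensional balanced mixture $0.5\,N(\theta, \sigma^2) + 0.5\,N(-\theta, \sigma^2)$ with known $\sigma$, where only $\theta$ is estimated. The function name `em_symmetric_gmm` is hypothetical.

```python
import math
import random

def em_symmetric_gmm(xs, theta0, sigma=1.0, num_steps=50):
    """EM for the mixture 0.5*N(theta, sigma^2) + 0.5*N(-theta, sigma^2), sigma known."""
    theta = theta0
    n = len(xs)
    for _ in range(num_steps):
        # E-step: the posterior probability that x came from the +theta component is
        # 1 / (1 + exp(-2*theta*x/sigma^2)), so 2*posterior - 1 = tanh(theta*x/sigma^2).
        # M-step: the new theta is the average of x weighted by (2*posterior - 1).
        theta = sum(math.tanh(theta * x / sigma**2) * x for x in xs) / n
    return theta


rng = random.Random(0)
true_theta = 2.0
sample = [rng.choice((-1.0, 1.0)) * true_theta + rng.gauss(0.0, 1.0) for _ in range(2000)]
estimate = em_symmetric_gmm(sample, theta0=0.5)
```

Starting from a positive initialization, the iterates in this well-separated example settle near the true value of 2, which is consistent with the kind of behavior the abstract describes for suitable initializations.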

5.2CLApr 29, 2014
Concise comparative summaries (CCS) of large text corpora with a human experiment

Jinzhu Jia, Luke Miratrix, Bin Yu et al.

In this paper we propose a general framework for topic-specific summarization of large text corpora and illustrate how it can be used for the analysis of news databases. Our framework, concise comparative summarization (CCS), is built on sparse classification methods. CCS is a lightweight and flexible tool that offers a compromise between simple word frequency based methods currently in wide use and more heavyweight, model-intensive methods such as latent Dirichlet allocation (LDA). We argue that sparse methods have much to offer for text analysis and hope CCS opens the door for a new branch of research in this important field. For a particular topic of interest (e.g., China or energy), CCS automatically labels documents as being either on- or off-topic (usually via keyword search), and then uses sparse classification methods to predict these labels with the high-dimensional counts of all the other words and phrases in the documents. The resulting small set of phrases found as predictive are then harvested as the summary. To validate our tool, we designed and conducted a human survey, using news articles from the New York Times international section, to compare the different summarizers with human understanding. We demonstrate our approach with two case studies, a media analysis of the framing of "Egypt" in the New York Times throughout the Arab Spring and an informal comparison of the New York Times' and Wall Street Journal's coverage of "energy." Overall, we find that the Lasso with $L^2$ normalization can be used effectively to summarize large corpora, regardless of document size.

7.5MLJul 10, 2013
Error Rate Bounds in Crowdsourcing Models

Hongwei Li, Bin Yu, Dengyong Zhou

Crowdsourcing is an effective tool for human-powered computation on many tasks challenging for computers. In this paper, we provide finite-sample exponential bounds on the error rate (in probability and in expectation) of hyperplane binary labeling rules under the Dawid-Skene crowdsourcing model. The bounds can be applied to analyze many common prediction methods, including majority voting and weighted majority voting. These bounds could be useful for controlling the error rate and designing better algorithms. We show that the oracle Maximum A Posteriori (MAP) rule approximately optimizes our upper bound on the mean error rate for any hyperplane binary labeling rule, and propose a simple data-driven weighted majority voting (WMV) rule (called one-step WMV) that attempts to approximate the oracle MAP and has a provable theoretical guarantee on the error rate. Moreover, we use simulated and real data to demonstrate that the data-driven EM-MAP rule is a good approximation to the oracle MAP rule, and to demonstrate that the mean error rate of the data-driven EM-MAP rule is also bounded by the mean error rate bound of the oracle MAP rule with estimated parameters plugged into the bound.

9.3MLSep 26, 2012
Reversible MCMC on Markov equivalence classes of sparse directed acyclic graphs

Yangbo He, Jinzhu Jia, Bin Yu

Graphical models are popular statistical tools which are used to represent dependent or causal complex systems. Statistically equivalent causal or directed graphical models are said to belong to a Markov equivalence class. It is of great interest to describe and understand the space of such classes. However, with currently known algorithms, sampling over such classes is only feasible for graphs with fewer than approximately 20 vertices. In this paper, we design reversible irreducible Markov chains on the space of Markov equivalence classes by proposing a perfect set of operators that determine the transitions of the Markov chain. The stationary distribution of a proposed Markov chain has a closed form and can be computed easily. Specifically, we construct a concrete perfect set of operators on sparse Markov equivalence classes by introducing appropriate conditions on each possible operator. Algorithms and their accelerated versions are provided to efficiently generate Markov chains and to explore properties of Markov equivalence classes of sparse directed acyclic graphs (DAGs) with thousands of vertices. We find experimentally that in most Markov equivalence classes of sparse DAGs, (1) most edges are directed, (2) most undirected subgraphs are small and (3) the number of these undirected subgraphs grows approximately linearly with the number of vertices. The article contains supplement arXiv:1303.0632, http://dx.doi.org/10.1214/13-AOS1125SUPP