Yoonho Lee

LG
h-index: 21
31 papers
2,570 citations
Novelty: 56%
AI Score: 59

31 Papers

LG · Nov 25, 2022 · Code
Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time

Huaxiu Yao, Caroline Choi, Bochuan Cao et al. · stanford

Distribution shift occurs when the test distribution differs from the training distribution, and it can considerably degrade performance of machine learning models deployed in the real world. Temporal shifts -- distribution shifts arising from the passage of time -- often occur gradually and have the additional structure of timestamp metadata. By leveraging timestamp metadata, models can potentially learn from trends in past distribution shifts and extrapolate into the future. While recent works have studied distribution shifts, temporal shifts remain underexplored. To address this gap, we curate Wild-Time, a benchmark of 5 datasets that reflect temporal distribution shifts arising in a variety of real-world applications, including patient prognosis and news classification. On these datasets, we systematically benchmark 13 prior approaches, including methods in domain generalization, continual learning, self-supervised learning, and ensemble learning. We use two evaluation strategies: evaluation with a fixed time split (Eval-Fix) and evaluation with a data stream (Eval-Stream). Eval-Fix, our primary evaluation strategy, aims to provide a simple evaluation protocol, while Eval-Stream is more realistic for certain real-world applications. Under both evaluation strategies, we observe an average performance drop of 20% from in-distribution to out-of-distribution data. Existing methods are unable to close this gap. Code is available at https://wild-time.github.io/.
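
The fixed time split behind Eval-Fix can be sketched as follows. This is a minimal illustration, not the benchmark's code: the function names, the `(timestamp, payload)` example format, and the choice to include the split timestamp in the past set are all assumptions.

```python
from typing import List, Sequence, Tuple

Example = Tuple[int, str]  # (timestamp, payload); a hypothetical record format


def eval_fix_split(
    examples: Sequence[Example], split_time: int
) -> Tuple[List[Example], List[Example]]:
    """Partition examples at a fixed timestamp.

    Examples at or before split_time form the in-distribution (past) set;
    later examples form the out-of-distribution (future) set.
    """
    past = [ex for ex in examples if ex[0] <= split_time]
    future = [ex for ex in examples if ex[0] > split_time]
    return past, future


def performance_drop(id_score: float, ood_score: float) -> float:
    """Gap between in-distribution and out-of-distribution performance."""
    return id_score - ood_score


past, future = eval_fix_split([(2010, "a"), (2015, "b"), (2020, "c")], 2015)
```

A model would be trained on part of `past`, then scored on held-out `past` data and on `future`; the difference between those two scores is the gap the abstract reports.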

CL · Jan 26, 2023
DetectGPT: Zero-Shot Machine-Generated Text Detection using Probability Curvature

Eric Mitchell, Yoonho Lee, Alexander Khazatsky et al. · stanford

The increasing fluency and widespread usage of large language models (LLMs) highlight the desirability of corresponding tools aiding detection of LLM-generated text. In this paper, we identify a property of the structure of an LLM's probability function that is useful for such detection. Specifically, we demonstrate that text sampled from an LLM tends to occupy negative curvature regions of the model's log probability function. Leveraging this observation, we then define a new curvature-based criterion for judging if a passage is generated from a given LLM. This approach, which we call DetectGPT, does not require training a separate classifier, collecting a dataset of real or generated passages, or explicitly watermarking generated text. It uses only log probabilities computed by the model of interest and random perturbations of the passage from another generic pre-trained language model (e.g., T5). We find DetectGPT is more discriminative than existing zero-shot methods for model sample detection, notably improving detection of fake news articles generated by 20B parameter GPT-NeoX from 0.81 AUROC for the strongest zero-shot baseline to 0.95 AUROC for DetectGPT. See https://ericmitchell.ai/detectgpt for code, data, and other project information.
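
The curvature criterion described above can be approximated by comparing a passage's log probability with the average log probability of its perturbed versions. The sketch below assumes this simple unnormalized form; `log_prob` and the perturbed passages are stand-ins supplied by the caller (in the paper they come from the model of interest and from a mask-filling model such as T5), and the toy scoring function is purely hypothetical.

```python
from statistics import mean
from typing import Callable, Sequence


def perturbation_discrepancy(
    passage: str,
    log_prob: Callable[[str], float],
    perturbed: Sequence[str],
) -> float:
    """Log probability of the passage minus the mean over its perturbations.

    A large positive value suggests the passage sits near a local maximum of
    the model's log probability, i.e. in a negative-curvature region.
    """
    return log_prob(passage) - mean(log_prob(p) for p in perturbed)


def toy_log_prob(text: str) -> int:
    """Hypothetical scorer that peaks for texts of length 10."""
    return -abs(len(text) - 10)


score = perturbation_discrepancy("a" * 10, toy_log_prob, ["a" * 8, "a" * 12])
```

A detector built on this would flag a passage as machine-generated when the score exceeds a threshold chosen on validation data.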

LG · Feb 10, 2023
Project and Probe: Sample-Efficient Domain Adaptation by Interpolating Orthogonal Features

Annie S. Chen, Yoonho Lee, Amrith Setlur et al. · cmu

Transfer learning with a small amount of target data is an effective and common approach to adapting a pre-trained model to distribution shifts. In some situations, target data labels may be expensive to obtain, so we may only have access to a limited number of target data points. To make the most of a very small target dataset, we propose a lightweight, sample-efficient approach that learns a diverse set of features and adapts to a target distribution by interpolating these features. Our approach, Project and Probe (Pro$^2$), first learns a linear projection that maps a pre-trained embedding onto orthogonal directions while being predictive of labels in the source dataset. The goal of this step is to learn a variety of predictive features, so that at least some of them remain useful after distribution shift. Pro$^2$ then learns a linear classifier on top of these projected features using a small target dataset. Theoretically, we find that Pro$^2$ results in more sample-efficient generalization by inducing a favorable bias-variance tradeoff. Our experiments on four datasets, with multiple distribution shift settings for each, show that Pro$^2$ improves performance by 5-15% when given limited target data compared to prior methods such as standard linear probing.

LG · Jun 19, 2023
Confidence-Based Model Selection: When to Take Shortcuts for Subpopulation Shifts

Annie S. Chen, Yoonho Lee, Amrith Setlur et al. · cmu

Effective machine learning models learn both robust features that directly determine the outcome of interest (e.g., an object with wheels is more likely to be a car), and shortcut features (e.g., an object on a road is more likely to be a car). The latter can be a source of error under distributional shift, when the correlations change at test-time. The prevailing sentiment in the robustness literature is to avoid such correlative shortcut features and learn robust predictors. However, while robust predictors perform better on worst-case distributional shifts, they often sacrifice accuracy on majority subpopulations. In this paper, we argue that shortcut features should not be entirely discarded. Instead, if we can identify the subpopulation to which an input belongs, we can adaptively choose among models with different strengths to achieve high performance on both majority and minority subpopulations. We propose COnfidence-baSed MOdel Selection (CosMoS), where we observe that model confidence can effectively guide model selection. Notably, CosMoS does not require any target labels or group annotations, either of which may be difficult to obtain or unavailable. We evaluate CosMoS on four datasets with spurious correlations, each with multiple test sets with varying levels of data distribution shift. We find that CosMoS achieves 2-5% lower average regret across all subpopulations, compared to using only robust predictors or other model aggregation methods.
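
The idea of letting confidence guide model selection can be illustrated with a deliberately simplified per-input rule: use whichever model assigns the highest probability to its own top class. This is an assumption made for illustration and may differ from how CosMoS actually aggregates confidence; the model names are hypothetical.

```python
from typing import Dict, List, Tuple


def select_by_confidence(probs_by_model: Dict[str, List[float]]) -> Tuple[str, int]:
    """Return (model name, predicted class) for the most confident model.

    probs_by_model maps a model name to its class probabilities for one input.
    """
    best = max(probs_by_model, key=lambda name: max(probs_by_model[name]))
    probs = probs_by_model[best]
    return best, probs.index(max(probs))


choice = select_by_confidence(
    {"robust": [0.6, 0.4], "shortcut": [0.1, 0.9]}
)
```

Note that no target labels or group annotations enter the rule, which matches the setting the abstract describes.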

LG · Oct 20, 2022
Surgical Fine-Tuning Improves Adaptation to Distribution Shifts

Yoonho Lee, Annie S. Chen, Fahim Tajwar et al.

A common approach to transfer learning under distribution shift is to fine-tune the last few layers of a pre-trained model, preserving learned features while also adapting to the new task. This paper shows that in such settings, selectively fine-tuning a subset of layers (which we term surgical fine-tuning) matches or outperforms commonly used fine-tuning approaches. Moreover, the type of distribution shift influences which subset is more effective to tune: for example, for image corruptions, fine-tuning only the first few layers works best. We validate our findings systematically across seven real-world data tasks spanning three types of distribution shifts. Theoretically, we prove that for two-layer neural networks in an idealized setting, first-layer tuning can outperform fine-tuning all layers. Intuitively, fine-tuning more parameters on a small target dataset can cause information learned during pre-training to be forgotten, and the relevant information depends on the type of shift.
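
Selecting a subset of layers to tune amounts to marking which parameters receive gradient updates. The sketch below shows that bookkeeping in plain Python; the parameter names and the prefix convention are hypothetical, and in a framework such as PyTorch the resulting mask would typically be applied by setting each parameter's `requires_grad` flag.

```python
from typing import Dict, Iterable, Sequence


def surgical_trainable_mask(
    param_names: Iterable[str], tuned_prefixes: Sequence[str]
) -> Dict[str, bool]:
    """Mark a parameter trainable only if its name starts with a chosen prefix."""
    prefixes = tuple(tuned_prefixes)
    return {name: name.startswith(prefixes) for name in param_names}


# Hypothetical parameter names for a three-block network with a classifier head.
names = ["block1.weight", "block2.weight", "block3.weight", "head.weight"]

# Tune only the first block, as the abstract suggests for image corruptions.
mask = surgical_trainable_mask(names, ["block1."])
```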

LG · Jun 8, 2023
Conservative Prediction via Data-Driven Confidence Minimization

Caroline Choi, Fahim Tajwar, Yoonho Lee et al.

In safety-critical applications of machine learning, it is often desirable for a model to be conservative, abstaining from making predictions on unknown inputs which are not well-represented in the training data. However, detecting unknown examples is challenging, as it is impossible to anticipate all potential inputs at test time. To address this, prior work (Hendrycks et al., 2018) minimizes model confidence on an auxiliary outlier dataset carefully curated to be disjoint from the training distribution. We theoretically analyze the choice of auxiliary dataset for confidence minimization, revealing two actionable insights: (1) if the auxiliary set contains unknown examples similar to those seen at test time, confidence minimization leads to provable detection of unknown test examples, and (2) if the first condition is satisfied, it is unnecessary to filter out known examples for out-of-distribution (OOD) detection. Motivated by these guidelines, we propose the Data-Driven Confidence Minimization (DCM) framework, which minimizes confidence on an uncertainty dataset. We apply DCM to two problem settings in which conservative prediction is paramount -- selective classification and OOD detection -- and provide a realistic way to gather uncertainty data for each setting. In our experiments, DCM consistently outperforms existing selective classification approaches on 4 datasets when tested on unseen distributions and outperforms state-of-the-art OOD detection methods on 12 ID-OOD dataset pairs, reducing FPR (at TPR 95%) by 6.3% and 58.1% on CIFAR-10 and CIFAR-100 compared to Outlier Exposure.
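
Confidence minimization can be written as a two-term objective: ordinary cross-entropy on labeled training data plus a penalty that pushes predictions on the uncertainty set toward the uniform distribution. The sketch below assumes that form (cross-entropy against a uniform target, as in Outlier Exposure); the function names and the default `weight` of 0.5 are illustrative assumptions, not values from the paper.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Standard cross-entropy against a one-hot label."""
    return -log_softmax(logits)[label]


def uniform_cross_entropy(logits: Sequence[float]) -> float:
    """Cross-entropy against the uniform distribution; smallest when the
    prediction is itself uniform, i.e. minimally confident."""
    log_probs = log_softmax(logits)
    return -sum(log_probs) / len(log_probs)


def confidence_minimization_loss(
    labeled: Sequence[Tuple[Sequence[float], int]],
    uncertain: Sequence[Sequence[float]],
    weight: float = 0.5,  # hypothetical trade-off coefficient
) -> float:
    """Fit the labeled data while lowering confidence on the uncertainty set."""
    fit = sum(cross_entropy(l, y) for l, y in labeled) / len(labeled)
    conf = sum(uniform_cross_entropy(l) for l in uncertain) / len(uncertain)
    return fit + weight * conf
```

For two classes the penalty reaches its minimum of log 2 at equal logits and grows as the prediction on an uncertain input becomes more confident.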

AI · Mar 30
Meta-Harness: End-to-End Optimization of Model Harnesses

Yoonho Lee, Roshen Nair, Qizheng Zhang et al.

The performance of large language model (LLM) systems depends not only on model weights, but also on their harness: the code that determines what information to store, retrieve, and present to the model. Yet harnesses are still designed largely by hand, and existing text optimizers are poorly matched to this setting because they compress feedback too aggressively. We introduce Meta-Harness, an outer-loop system that searches over harness code for LLM applications. It uses an agentic proposer that accesses the source code, scores, and execution traces of all prior candidates through a filesystem. On online text classification, Meta-Harness improves over a state-of-the-art context management system by 7.7 points while using 4x fewer context tokens. On retrieval-augmented math reasoning, a single discovered harness improves accuracy on 200 IMO-level problems by 4.7 points on average across five held-out models. On agentic coding, discovered harnesses surpass the best hand-engineered baselines on TerminalBench-2. Together, these results show that richer access to prior experience can enable automated harness engineering.

RO · Aug 30, 2024
Bidirectional Decoding: Improving Action Chunking via Guided Test-Time Sampling

Yuejiang Liu, Jubayer Ibn Hamid, Annie Xie et al.

Predicting and executing a sequence of actions without intermediate replanning, known as action chunking, is increasingly used in robot learning from human demonstrations. Yet, its effects on the learned policy remain inconsistent: some studies find it crucial for achieving strong results, while others observe decreased performance. In this paper, we first dissect how action chunking impacts the divergence between a learner and a demonstrator. We find that action chunking allows the learner to better capture the temporal dependencies in demonstrations but at the cost of reduced reactivity to unexpected states. To address this tradeoff, we propose Bidirectional Decoding (BID), a test-time inference algorithm that bridges action chunking with closed-loop adaptation. At each timestep, BID samples multiple candidate predictions and searches for the optimal one based on two criteria: (i) backward coherence, which favors samples that align with previous decisions; (ii) forward contrast, which seeks samples of high likelihood for future plans. By coupling decisions within and across action chunks, BID promotes both long-term consistency and short-term reactivity. Experimental results show that our method boosts the performance of two state-of-the-art generative policies across seven simulation benchmarks and two real-world tasks. Code and videos are available at https://bid-robot.github.io.

LG · Sep 29, 2024
Calibrating Language Models with Adaptive Temperature Scaling

Johnathan Xie, Annie S. Chen, Yoonho Lee et al.

The effectiveness of large language models (LLMs) is not only measured by their ability to generate accurate outputs but also by their calibration -- how well their confidence scores reflect the probability of their outputs being correct. While unsupervised pre-training has been shown to yield LLMs with well-calibrated conditional probabilities, recent studies have shown that after fine-tuning with reinforcement learning from human feedback (RLHF), the calibration of these models degrades significantly. In this work, we introduce Adaptive Temperature Scaling (ATS), a post-hoc calibration method that predicts a temperature scaling parameter for each token prediction. The predicted temperature values adapt based on token-level features and are fit over a standard supervised fine-tuning (SFT) dataset. The adaptive nature of ATS addresses the varying degrees of calibration shift that can occur after RLHF fine-tuning. ATS improves calibration by over 10-50% across three downstream natural language evaluation benchmarks compared to prior calibration methods and does not impede performance improvements from RLHF.
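
Per-token temperature scaling divides each token's logits by its own temperature before the softmax. The sketch below shows only that scaling step; the temperatures are passed in directly, whereas in ATS they would be predicted from token-level features by a learned component that is not modeled here. All names are illustrative.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Softmax over logits divided by a positive temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def adaptive_temperature_scale(
    token_logits: Sequence[Sequence[float]], temperatures: Sequence[float]
) -> List[List[float]]:
    """Apply a separate temperature to each token position's logits."""
    return [softmax(l, t) for l, t in zip(token_logits, temperatures)]


# Same logits at two positions; the second gets a higher temperature.
probs = adaptive_temperature_scale([[2.0, 0.0], [2.0, 0.0]], [1.0, 2.0])
```

A temperature above 1 flattens the distribution and lowers the top probability without changing which token ranks first, so accuracy is unaffected while confidence is adjusted.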

LG · Oct 12, 2022
On Divergence Measures for Bayesian Pseudocoresets

Balhae Kim, Jungwon Choi, Seanie Lee et al.

A Bayesian pseudocoreset is a small synthetic dataset for which the posterior over parameters approximates that of the original dataset. While promising, the scalability of Bayesian pseudocoresets is not yet validated in realistic problems such as image classification with deep neural networks. On the other hand, dataset distillation methods similarly construct a small dataset such that the optimization using the synthetic dataset converges to a solution with performance competitive with optimization using full data. Although dataset distillation has been empirically verified in large-scale settings, the framework is restricted to point estimates, and their adaptation to Bayesian inference has not been explored. This paper casts two representative dataset distillation algorithms as approximations to methods for constructing pseudocoresets by minimizing specific divergence measures: reverse KL divergence and Wasserstein distance. Furthermore, we provide a unifying view of such divergence measures in Bayesian pseudocoreset construction. Finally, we propose a novel Bayesian pseudocoreset algorithm based on minimizing forward KL divergence. Our empirical results demonstrate that the pseudocoresets constructed from these methods reflect the true posterior even in high-dimensional Bayesian inference problems.

LG · Mar 6, 2025 · Code
Subgraph Federated Learning for Local Generalization

Sungwon Kim, Yoonho Lee, Yunhak Oh et al.

Federated Learning (FL) on graphs enables collaborative model training to enhance performance without compromising the privacy of each client. However, existing methods often overlook the mutable nature of graph data, which frequently introduces new nodes and leads to shifts in label distribution. Since they focus solely on performing well on each client's local data, they are prone to overfitting to their local distributions (i.e., local overfitting), which hinders their ability to generalize to unseen data with diverse label distributions. In contrast, our proposed method, FedLoG, effectively tackles this issue by mitigating local overfitting. Our model generates global synthetic data by condensing the reliable information from each class representation and its structural information across clients. Using these synthetic data as a training set, we alleviate the local overfitting problem by adaptively generalizing the absent knowledge within each local dataset. This enhances the generalization capabilities of local models, enabling them to handle unseen data effectively. Our model outperforms baselines in our proposed experimental settings, which are designed to measure generalization power to unseen data in practical scenarios. Our code is available at https://github.com/sung-won-kim/FedLoG

LG · Nov 11, 2025
Feedback Descent: Open-Ended Text Optimization via Pairwise Comparison

Yoonho Lee, Joseph Boen, Chelsea Finn

We introduce Feedback Descent, a framework that optimizes text artifacts -- prompts, code, and molecules -- through structured textual feedback, rather than relying solely on scalar rewards. By preserving detailed critiques instead of compressing them to binary preferences, Feedback Descent widens the information bottleneck in preference learning, enabling directed optimization in text space rather than weight space. We show that in-context learning can transform structured feedback into gradient-like directional information, enabling targeted edits. Unlike prior approaches that collapse judgments into single bits, our evaluators pair each comparison with textual feedback, which functions as high-bandwidth supervision. The iteration loop is done purely at inference time, without modifying any model weights, and is task-agnostic. We evaluate Feedback Descent on three diverse domains and find that it outperforms state-of-the-art prompt optimization (GEPA), reinforcement learning methods (GRPO, REINVENT), and even specialized graph-based molecular optimizers. In the DOCKSTRING molecule discovery benchmark, Feedback Descent identifies novel drug-like molecules surpassing the 99.9th percentile of a database with more than 260,000 compounds across six protein targets.

ML · May 28, 2019 · Code
Learning Dynamics of Attention: Human Prior for Interpretable Machine Reasoning

Wonjae Kim, Yoonho Lee

Without relevant human priors, neural networks may learn uninterpretable features. We propose Dynamics of Attention for Focus Transition (DAFT) as a human prior for machine reasoning. DAFT is a novel method that regularizes attention-based reasoning by modelling it as a continuous dynamical system using neural ordinary differential equations. As a proof of concept, we augment a state-of-the-art visual reasoning model with DAFT. Our experiments reveal that applying DAFT yields similar performance to the original model while using fewer reasoning steps, showing that it implicitly learns to skip unnecessary steps. We also propose a new metric, Total Length of Transition (TLT), which represents the effective reasoning step size by quantifying how much a given model's focus drifts while reasoning about a question. We show that adding DAFT results in lower TLT, demonstrating that our method indeed obeys the human prior towards shorter reasoning paths in addition to producing more interpretable attention maps. Our code is available at https://github.com/kakao/DAFT.

LG · Feb 6, 2024
Clarify: Improving Model Robustness With Natural Language Corrections

Yoonho Lee, Michelle S. Lam, Helena Vasconcelos et al.

The standard way to teach models is by feeding them lots of data. However, this approach often teaches models incorrect ideas because they pick up on misleading signals in the data. To prevent such misconceptions, we must necessarily provide additional information beyond the training data. Prior methods incorporate additional instance-level supervision, such as labels for misleading features or additional labels for debiased data. However, such strategies require a large amount of labeler effort. We hypothesize that people are good at providing textual feedback at the concept level, a capability that existing teaching frameworks do not leverage. We propose Clarify, a novel interface and method for interactively correcting model misconceptions. Through Clarify, users need only provide a short text description of a model's consistent failure patterns. Then, in an entirely automated way, we use such descriptions to improve the training process. Clarify is the first end-to-end system for user model correction. Our user studies show that non-expert users can successfully describe model misconceptions via Clarify, leading to increased worst-case performance in two datasets. We additionally conduct a case study on a large-scale image dataset, ImageNet, using Clarify to find and rectify 31 novel hard subpopulations.

LG · Dec 11, 2024
Test-Time Alignment via Hypothesis Reweighting

Yoonho Lee, Jonathan Williams, Henrik Marklund et al. · stanford

Large pretrained models often struggle with underspecified tasks -- situations where the training data does not fully define the desired behavior. For example, chatbots must handle diverse and often conflicting user preferences, requiring adaptability to various user needs. We propose a novel framework to address the general challenge of aligning models to test-time user intent, which is rarely fully specified during training. Our approach involves training an efficient ensemble, i.e., a single neural network with multiple prediction heads, each representing a different function consistent with the training data. Our main contribution is HyRe, a simple adaptation technique that dynamically reweights ensemble members at test time using a small set of labeled examples from the target distribution, which can be labeled in advance or actively queried from a larger unlabeled pool. By leveraging recent advances in scalable ensemble training, our method scales to large pretrained models, with computational costs comparable to fine-tuning a single model. We empirically validate HyRe in several underspecified scenarios, including personalization tasks and settings with distribution shifts. Additionally, with just five preference pairs from each target distribution, the same ensemble adapted via HyRe outperforms the prior state-of-the-art 2B-parameter reward model accuracy across 18 evaluation distributions.
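
Test-time reweighting of ensemble members can be illustrated with an exponential-weights rule: each head's weight shrinks with its accumulated loss on the few labeled target examples. This rule is an assumption chosen for illustration and may not match HyRe's exact update; the function names are hypothetical.

```python
import math
from typing import List, Sequence


def reweight(head_losses: Sequence[float]) -> List[float]:
    """Turn each head's summed adaptation loss into a normalized weight.

    Lower loss gives higher weight; subtracting the minimum keeps exp() stable.
    """
    m = min(head_losses)
    unnorm = [math.exp(-(loss - m)) for loss in head_losses]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def ensemble_predict(
    head_probs: Sequence[Sequence[float]], weights: Sequence[float]
) -> List[float]:
    """Weighted average of the heads' class probabilities for one input."""
    n_classes = len(head_probs[0])
    return [
        sum(w * p[c] for w, p in zip(weights, head_probs))
        for c in range(n_classes)
    ]


# Head 0 fits the target examples much better than head 1.
weights = reweight([0.1, 3.0])
mixed = ensemble_predict([[0.9, 0.1], [0.2, 0.8]], weights)
```

Because only the weights change, adaptation needs no gradient steps on the network itself.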

LG · Feb 22, 2024
Self-Guided Masked Autoencoders for Domain-Agnostic Self-Supervised Learning

Johnathan Xie, Yoonho Lee, Annie S. Chen et al.

Self-supervised learning excels in learning representations from large amounts of unlabeled data, demonstrating success across multiple data modalities. Yet, extending self-supervised learning to new modalities is non-trivial because the specifics of existing methods are tailored to each domain, such as domain-specific augmentations which reflect the invariances in the target task. While masked modeling is promising as a domain-agnostic framework for self-supervised learning because it does not rely on input augmentations, its mask sampling procedure remains domain-specific. We present Self-guided Masked Autoencoders (SMA), a fully domain-agnostic masked modeling method. SMA trains an attention based model using a masked modeling objective, by learning masks to sample without any domain-specific assumptions. We evaluate SMA on three self-supervised learning benchmarks in protein biology, chemical property prediction, and particle physics. We find SMA is capable of learning representations without domain-specific knowledge and achieves state-of-the-art performance on these three benchmarks.

AI · Oct 2, 2025
RLAD: Training LLMs to Discover Abstractions for Solving Reasoning Problems

Yuxiao Qu, Anikait Singh, Yoonho Lee et al. · cmu, stanford

Reasoning requires going beyond pattern matching or memorization of solutions to identify and implement "algorithmic procedures" that can be used to deduce answers to hard problems. Doing so requires realizing the most relevant primitives, intermediate results, or shared procedures, and building upon them. While RL post-training on long chains of thought ultimately aims to uncover this kind of algorithmic behavior, most reasoning traces learned by large models fail to consistently capture or reuse procedures, instead drifting into verbose and degenerate exploration. To enable more effective reasoning, we introduce reasoning abstractions: concise natural language descriptions of procedural and factual knowledge that guide the model toward learning successful reasoning. We train models to be capable of proposing multiple abstractions given a problem, followed by RL that incentivizes building a solution while using the information provided by these abstractions. This results in a two-player RL training paradigm, abbreviated as RLAD, that jointly trains an abstraction generator and a solution generator. This setup effectively enables structured exploration, decouples learning signals of abstraction proposal and solution generation, and improves generalization to harder problems. We also show that allocating more test-time compute to generating abstractions is more beneficial for performance than generating more solutions at large test budgets, illustrating the role of abstractions in guiding meaningful exploration.

LG · Oct 18, 2025
Disentangling Hyperedges through the Lens of Category Theory

Yoonho Lee, Junseok Lee, Sangwoo Seo et al.

Despite the promising results of disentangled representation learning in discovering latent patterns in graph-structured data, few studies have explored disentanglement for hypergraph-structured data. Integrating hyperedge disentanglement into hypergraph neural networks enables models to leverage hidden hyperedge semantics, such as unannotated relations between nodes, that are associated with labels. This paper presents an analysis of hyperedge disentanglement from a category-theoretical perspective and proposes a novel criterion for disentanglement derived from the naturality condition. Our proof-of-concept model experimentally showed the potential of the proposed criterion by successfully capturing functional relations of genes (nodes) in genetic pathways (hyperedges).

LG · Jul 18, 2025
Target Circuit Matching in Large-Scale Netlists using GNN-Based Region Prediction

Sangwoo Seo, Jimin Seo, Yoonho Lee et al.

Subgraph matching plays an important role in electronic design automation (EDA) and circuit verification. Traditional rule-based methods have limitations in generalizing to arbitrary target circuits. Furthermore, node-to-node matching approaches tend to be computationally inefficient, particularly for large-scale circuits. Deep learning methods have emerged as a potential solution to address these challenges, but existing models fail to efficiently capture global subgraph embeddings or rely on inefficient matching matrices, which limits their effectiveness for large circuits. In this paper, we propose an efficient graph matching approach that utilizes Graph Neural Networks (GNNs) to predict regions of high probability for containing the target circuit. Specifically, we construct various negative samples to enable GNNs to accurately learn the presence of target circuits and develop an approach to directly extracting subgraph embeddings from the entire circuit, which captures global subgraph information and addresses the inefficiency of applying GNNs to all candidate subgraphs. Extensive experiments demonstrate that our approach significantly outperforms existing methods in terms of time efficiency and target region prediction, offering a scalable and effective solution for subgraph matching in large-scale circuits.

LG · Jun 19, 2024
Self-Explainable Temporal Graph Networks based on Graph Information Bottleneck

Sangwoo Seo, Sungwon Kim, Jihyeong Jung et al.

Temporal Graph Neural Networks (TGNN) have the ability to capture both the graph topology and dynamic dependencies of interactions within a graph over time. There has been a growing need to explain the predictions of TGNN models due to the difficulty in identifying how past events influence their predictions. Since the explanation model for a static graph cannot be readily applied to temporal graphs due to its inability to capture temporal dependencies, recent studies proposed explanation models for temporal graphs. However, existing explanation models for temporal graphs rely on post-hoc explanations, requiring separate models for prediction and explanation, which is limited in two aspects: efficiency and accuracy of explanation. In this work, we propose a novel built-in explanation framework for temporal graphs, called Self-Explainable Temporal Graph Networks based on Graph Information Bottleneck (TGIB). TGIB provides explanations for event occurrences by introducing stochasticity in each temporal event based on the Information Bottleneck theory. Experimental results demonstrate the superiority of TGIB in terms of both the link prediction performance and explainability compared to state-of-the-art methods. This is the first work that simultaneously performs prediction and explanation for temporal graphs in an end-to-end manner.

LG · Jan 18, 2024
AutoFT: Learning an Objective for Robust Fine-Tuning

Caroline Choi, Yoonho Lee, Annie Chen et al.

Foundation models encode rich representations that can be adapted to downstream tasks by fine-tuning. However, fine-tuning a model on one data distribution often degrades performance under distribution shifts. Current approaches to robust fine-tuning use hand-crafted regularization techniques to constrain the fine-tuning process towards the pretrained model. Yet, it is hard to specify how to adapt relevant characteristics of the foundation model during fine-tuning, as this depends on how the pre-training, fine-tuning, and test data distributions relate to each other. We propose AutoFT, a data-driven approach for robust fine-tuning. Given a task, AutoFT searches for a fine-tuning procedure that enhances out-of-distribution (OOD) generalization. Specifically, AutoFT uses bi-level optimization to search for an objective function and hyperparameters that maximize post-adaptation performance on a small OOD validation set. We evaluate AutoFT on nine natural distribution shifts. Our experiments show that AutoFT significantly improves generalization to OOD inputs, outperforming existing robust fine-tuning methods. Notably, AutoFT achieves a new state-of-the-art on the WILDS iWildCam and FMoW benchmarks, outperforming the previous best methods by 6.0% and 1.5%, respectively.

LG · Feb 7, 2022
Diversify and Disambiguate: Learning From Underspecified Data

Yoonho Lee, Huaxiu Yao, Chelsea Finn

Many datasets are underspecified: there exist multiple equally viable solutions to a given task. Underspecification can be problematic for methods that learn a single hypothesis because different functions that achieve low training loss can focus on different predictive features and thus produce widely varying predictions on out-of-distribution data. We propose DivDis, a simple two-stage framework that first learns a diverse collection of hypotheses for a task by leveraging unlabeled data from the test distribution. We then disambiguate by selecting one of the discovered hypotheses using minimal additional supervision, in the form of additional labels or inspection of function visualization. We demonstrate the ability of DivDis to find hypotheses that use robust features in image classification and natural language processing problems with underspecification.
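
The second, disambiguation stage can be sketched as picking the hypothesis that agrees best with a handful of labeled examples. The code below covers only that selection step under the assumption that supervision arrives as labels; the first stage, which trains diverse heads on unlabeled test-distribution data, is not shown, and the names are illustrative.

```python
from typing import Sequence


def disambiguate(
    head_predictions: Sequence[Sequence[int]], labels: Sequence[int]
) -> int:
    """Return the index of the head with the highest accuracy on the labels.

    head_predictions[k][i] is head k's predicted class for labeled example i.
    Ties resolve to the lowest index.
    """

    def accuracy(preds: Sequence[int]) -> float:
        return sum(p == y for p, y in zip(preds, labels)) / len(labels)

    return max(range(len(head_predictions)), key=lambda k: accuracy(head_predictions[k]))


# Two hypothetical heads scored on three labeled examples.
chosen = disambiguate([[0, 0, 1], [1, 0, 1]], [1, 0, 1])
```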

LG · Oct 27, 2021
Diversity Matters When Learning From Ensembles

Giung Nam, Jongmin Yoon, Yoonho Lee et al.

Deep ensembles excel in large-scale image classification tasks both in terms of prediction accuracy and calibration. Despite being simple to train, the computation and memory cost of deep ensembles limits their practicability. While some recent works propose to distill an ensemble model into a single model to reduce such costs, there is still a performance gap between the ensemble and distilled models. We propose a simple approach for reducing this gap, i.e., making the distilled performance close to the full ensemble. Our key assumption is that a distilled model should absorb as much function diversity inside the ensemble as possible. We first empirically show that the typical distillation procedure does not effectively transfer such diversity, especially for complex models that achieve near-zero training error. To fix this, we propose a perturbation strategy for distillation that reveals diversity by seeking inputs for which ensemble member outputs disagree. We empirically show that a model distilled with such perturbed samples indeed exhibits enhanced diversity, leading to improved performance.

LGJul 5, 2021
On The Distribution of Penultimate Activations of Classification Networks

Minkyo Seo, Yoonho Lee, Suha Kwak

This paper studies probability distributions of penultimate activations of classification networks. We show that, when a classification network is trained with the cross-entropy loss, its final classification layer forms a Generative-Discriminative pair with a generative classifier based on a specific distribution of penultimate activations. More importantly, the distribution is parameterized by the weights of the final fully-connected layer, and can be considered as a generative model that synthesizes the penultimate activations without feeding input data. We empirically demonstrate that this generative model enables stable knowledge distillation in the presence of domain shift, and can transfer knowledge from a classifier to variational autoencoders and generative adversarial networks for class-conditional image generation.

MLOct 29, 2020
Amortized Probabilistic Detection of Communities in Graphs

Yueqi Wang, Yoonho Lee, Pallab Basu et al.

Learning community structures in graphs has broad applications across scientific domains. While graph neural networks (GNNs) have been successful in encoding graph structures, existing GNN-based methods for community detection are limited by requiring knowledge of the number of communities in advance, in addition to lacking a proper probabilistic formulation to handle uncertainty. We propose a simple framework for amortized community detection, which addresses both of these issues by combining the expressive power of GNNs with recent methods for amortized clustering. Our models consist of a graph representation backbone that extracts structural information and an amortized clustering network that naturally handles variable numbers of clusters. Both components combine into well-defined models of the posterior distribution of graph communities and are jointly optimized given labeled graphs. At inference time, the models yield parallel samples from the posterior of community labels, quantifying uncertainty in a principled way. We evaluate several models from our framework on synthetic and real datasets, and demonstrate improved performance compared to previous methods. As a separate contribution, we extend recent amortized probabilistic clustering architectures by adding attention modules, which yield further improvements on community detection tasks.

LGAug 7, 2020
Bootstrapping Neural Processes

Juho Lee, Yoonho Lee, Jungtaek Kim et al.

Unlike traditional statistical modeling, in which a user typically hand-specifies a prior, Neural Processes (NPs) implicitly define a broad class of stochastic processes with neural networks. Given a data stream, an NP learns a stochastic process that best describes the data. While this "data-driven" way of learning stochastic processes has proven to handle various types of data, NPs still rely on an assumption that uncertainty in stochastic processes is modeled by a single latent variable, which potentially limits the flexibility. To this end, we propose the Bootstrapping Neural Process (BNP), a novel extension of the NP family using the bootstrap. The bootstrap is a classical data-driven technique for estimating uncertainty, which allows BNP to learn the stochasticity in NPs without assuming a particular form. We demonstrate the efficacy of BNP on various types of data and its robustness in the presence of model-data mismatch.
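For readers unfamiliar with the classical bootstrap mentioned above, the following is a minimal sketch of the general technique: resample the data with replacement and read uncertainty off the spread of the resampled statistic. It is not the BNP model itself; the function name, the percentile-interval choice, and the sample data are assumptions made for illustration.

```python
import random

def bootstrap_mean_interval(data, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap interval for the mean of `data`."""
    rng = random.Random(seed)
    n = len(data)
    means = []
    for _ in range(n_resamples):
        # Draw n points with replacement and record the resampled mean.
        resample = rng.choices(data, k=n)
        means.append(sum(resample) / n)
    means.sort()
    lo = means[int((alpha / 2) * n_resamples)]
    hi = means[int((1 - alpha / 2) * n_resamples) - 1]
    return lo, hi

data = [2.1, 2.5, 1.9, 3.2, 2.8, 2.4, 2.0, 3.0]  # hypothetical observations
lo, hi = bootstrap_mean_interval(data)
```

No distributional form is assumed anywhere in this procedure, which is the property the abstract appeals to.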

LGAug 7, 2020
Neural Complexity Measures

Yoonho Lee, Juho Lee, Sung Ju Hwang et al.

While various complexity measures for deep neural networks exist, specifying an appropriate measure capable of predicting and explaining generalization in deep networks has proven challenging. We propose Neural Complexity (NC), a meta-learning framework for predicting generalization. Our model learns a scalar complexity measure through interactions with many heterogeneous tasks in a data-driven way. The trained NC model can be added to the standard training loss to regularize any task learner in a standard supervised learning scenario. We contrast NC's approach against existing manually-designed complexity measures and other meta-learning models, and we validate NC's performance on multiple regression and classification tasks.

LGSep 30, 2019
Deep Amortized Clustering

Juho Lee, Yoonho Lee, Yee Whye Teh

We propose deep amortized clustering (DAC), a neural architecture which learns to cluster datasets efficiently using a few forward passes. DAC implicitly learns what makes a cluster, how to group data points into clusters, and how to count the number of clusters in datasets. DAC is meta-learned using labelled datasets for training, a process distinct from traditional clustering algorithms which usually require hand-specified prior knowledge about cluster shapes/structures. We empirically show, on both synthetic and image data, that DAC can efficiently and accurately cluster new datasets coming from the same distribution used to generate training datasets.

MLMay 28, 2019
Discrete Infomax Codes for Supervised Representation Learning

Yoonho Lee, Wonjae Kim, Wonpyo Park et al.

Learning compact discrete representations of data is a key task on its own or for facilitating subsequent processing of data. In this paper we present a model that produces Discrete InfoMax Codes (DIMCO); we learn a probabilistic encoder that yields k-way d-dimensional codes associated with input data. Our model's learning objective is to maximize the mutual information between codes and labels with a regularization, which enforces entries of a codeword to be as independent as possible. We show that the infomax principle also justifies previous loss functions (e.g., cross-entropy) as its special cases. Our analysis also shows that using shorter codes, as DIMCO does, reduces overfitting in the context of few-shot classification. Through experiments in various domains, we observe this implicit meta-regularization effect of DIMCO. Furthermore, we show that the codes learned by DIMCO are efficient in terms of both memory and retrieval time compared to previous methods.

LGOct 1, 2018
Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks

Juho Lee, Yoonho Lee, Jungtaek Kim et al.

Many machine learning tasks such as multiple instance learning, 3D shape recognition, and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the order of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from the sparse Gaussian process literature. It reduces the computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating state-of-the-art performance compared to recent methods for set-structured data.
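The quadratic-to-linear reduction can be sketched as follows. This is a simplified illustration under stated assumptions, not the paper's implementation: it omits learned projections, multiple heads, and feed-forward layers, and the inducing points are fixed hypothetical values rather than learned parameters. With n set elements and m inducing points, the two attention steps compute 2·n·m scores instead of the n² scores of full self-attention.

```python
import math

def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(queries, keys, values):
    """Scaled dot-product attention without learned projections."""
    scale = math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        weights = softmax(
            [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        )
        out.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return out

def induced_attention(x, inducing):
    # Step 1: the m inducing points attend to the n elements (m * n scores).
    summary = attend(inducing, x, x)
    # Step 2: the n elements attend to the m summaries (n * m scores).
    return attend(x, summary, summary)

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
inducing = [[0.3, 0.1], [-0.2, 0.4]]  # hypothetical; learned in practice
out = induced_attention(x, inducing)
```

The output has one row per input element, in input order, so permuting the input permutes the output rows in the same way. A pooling step over those rows would then give an order-independent result.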

MLJan 17, 2018
Gradient-Based Meta-Learning with Learned Layerwise Metric and Subspace

Yoonho Lee, Seungjin Choi

Gradient-based meta-learning methods leverage gradient descent to learn the commonalities among various tasks. While previous such methods have been successful in meta-learning tasks, they resort to simple gradient descent during meta-testing. Our primary contribution is the MT-net, which enables the meta-learner to learn on each layer's activation space a subspace that the task-specific learner performs gradient descent on. Additionally, a task-specific learner of an MT-net performs gradient descent with respect to a meta-learned distance metric, which warps the activation space to be more sensitive to task identity. We demonstrate that the dimension of this learned subspace reflects the complexity of the task-specific learner's adaptation task, and also that our model is less sensitive to the choice of initial learning rates than previous gradient-based meta-learning methods. Our method achieves state-of-the-art or comparable performance on few-shot classification and regression tasks.