LG · AI · ML · May 14, 2021

Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets

arXiv:2105.06987v1 · 21 citations
Originality: Incremental advance
AI Analysis

This work addresses the high inference cost of ensembles for practitioners who need efficient uncertainty estimates in tasks with many classes; it represents an incremental improvement over existing distillation methods.

The paper tackles the poor convergence of Ensemble Distribution Distillation on large-scale classification tasks with many classes. It proposes a new training objective that minimizes the reverse KL-divergence to a Proxy-Dirichlet target, which resolves the gradient issues and proves effective on datasets such as ImageNet (1000 classes) and WMT17 En-De (40,000 classes).
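To make the convergence issue concrete, the following pure-Python sketch implements the maximum-likelihood criterion: the mean negative log-likelihood of the ensemble members' output distributions under a Dirichlet whose concentrations are α_k = exp(z_k), together with its gradient with respect to the log-concentrations z_k. The function names (digamma, dirichlet_nll, dirichlet_nll_grad), the log-concentration parameterization, and the clipping constant eps are assumptions for illustration, not the paper's code. At a uniform initialization (all z_k = 0), the −mean log π_k term is large for tail classes, so their gradients come out larger in magnitude than the head class's.

```python
import math


def digamma(x: float) -> float:
    """Digamma function for x > 0 via recurrence plus an asymptotic series."""
    result = 0.0
    # Shift x upward using psi(x) = psi(x + 1) - 1/x until the series is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    result += math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
    return result


def dirichlet_nll(log_alpha, ensemble_probs, eps=1e-12):
    """Hypothetical helper: mean NLL of the members' probability vectors under Dir(alpha)."""
    alpha = [math.exp(z) for z in log_alpha]
    norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    total = 0.0
    for probs in ensemble_probs:
        # Clip to avoid log(0) for classes a member assigns zero probability.
        log_p = [math.log(max(p, eps)) for p in probs]
        total += norm + sum((a - 1.0) * lp for a, lp in zip(alpha, log_p))
    return -total / len(ensemble_probs)


def dirichlet_nll_grad(log_alpha, ensemble_probs, eps=1e-12):
    """Hypothetical helper: gradient of dirichlet_nll w.r.t. z_k = log(alpha_k)."""
    alpha = [math.exp(z) for z in log_alpha]
    psi0 = digamma(sum(alpha))
    m = len(ensemble_probs)
    grads = []
    for k, a in enumerate(alpha):
        mean_log_p = sum(math.log(max(p[k], eps)) for p in ensemble_probs) / m
        # dL/dalpha_k = psi(alpha_k) - psi(alpha_0) - mean log pi_k; chain rule multiplies by alpha_k.
        grads.append(a * (digamma(a) - psi0 - mean_log_p))
    return grads


if __name__ == "__main__":
    # Two toy ensemble members over four classes; the last class is a tail class.
    ensemble = [[0.9, 0.09, 0.009, 0.001], [0.85, 0.13, 0.015, 0.005]]
    print(dirichlet_nll_grad([0.0] * 4, ensemble))
```

This toy check only illustrates the trend at one initialization; it is not a reproduction of the paper's gradient analysis.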

Ensembles of machine learning models yield improved system performance as well as robust and interpretable uncertainty estimates; however, their inference costs may often be prohibitively high. Ensemble Distribution Distillation is an approach that allows a single model to efficiently capture both the predictive performance and the uncertainty estimates of an ensemble. For classification, this is achieved by training a Dirichlet distribution over the ensemble members' output distributions via the maximum likelihood criterion. Although theoretically principled, this criterion exhibits poor convergence when applied to large-scale tasks where the number of classes is very high. In our work, we analyze this effect and show that, under the Dirichlet log-likelihood criterion, classes with low probability induce larger gradients than high-probability classes. This forces the model to focus on the distribution of the ensemble tail-class probabilities. We propose a new training objective that minimizes the reverse KL-divergence to a Proxy-Dirichlet target derived from the ensemble. This loss resolves the gradient issues of Ensemble Distribution Distillation, as we demonstrate both theoretically and empirically on the ImageNet and WMT17 En-De datasets, which contain 1000 and 40,000 classes, respectively.
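The proposed objective can be sketched as follows, assuming one plausible construction of the Proxy-Dirichlet target: its mean is the ensemble's mean prediction, and its precision is set with Minka's fixed-mean approximation from the gap between log-mean and mean-log member probabilities. That precision formula, the clipping constant eps, and the names proxy_dirichlet, dirichlet_kl, and reverse_kl_loss are assumptions for illustration; the paper's exact target may differ. The loss is "reverse" because the student's Dirichlet is the first argument of the KL-divergence and the proxy target the second.

```python
import math


def digamma(x: float) -> float:
    """Digamma function for x > 0 via recurrence plus an asymptotic series."""
    result = 0.0
    while x < 6.0:  # psi(x) = psi(x + 1) - 1/x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    result += math.log(x) - 0.5 / x - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
    return result


def proxy_dirichlet(ensemble_probs, eps=1e-8):
    """Hypothetical Proxy-Dirichlet: ensemble-mean direction, Minka-style precision (assumption)."""
    m, k = len(ensemble_probs), len(ensemble_probs[0])
    clipped = [[max(p, eps) for p in probs] for probs in ensemble_probs]
    mean = [sum(q[j] for q in clipped) / m for j in range(k)]
    mean_log = [sum(math.log(q[j]) for q in clipped) / m for j in range(k)]
    # Jensen: log(mean_j) >= mean_log_j, so the gap is non-negative;
    # it is zero only when all members agree, hence the eps floor.
    gap = sum(mu * (math.log(mu) - ml) for mu, ml in zip(mean, mean_log))
    precision = (k - 1) / (2.0 * max(gap, eps))
    return [precision * mu for mu in mean]


def dirichlet_kl(alpha, beta):
    """Closed-form KL(Dir(alpha) || Dir(beta))."""
    a0, b0 = sum(alpha), sum(beta)
    psi_a0 = digamma(a0)
    kl = math.lgamma(a0) - math.lgamma(b0)
    kl += sum(math.lgamma(b) - math.lgamma(a) for a, b in zip(alpha, beta))
    kl += sum((a - b) * (digamma(a) - psi_a0) for a, b in zip(alpha, beta))
    return kl


def reverse_kl_loss(student_alpha, ensemble_probs):
    """Reverse KL: student Dirichlet first, proxy target second."""
    return dirichlet_kl(student_alpha, proxy_dirichlet(ensemble_probs))


if __name__ == "__main__":
    ensemble = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1], [0.8, 0.1, 0.1]]
    print(proxy_dirichlet(ensemble))
    print(reverse_kl_loss([2.0, 1.0, 1.0], ensemble))
```

Because the target is a single Dirichlet summarizing the ensemble, the loss compares two Dirichlets in closed form instead of scoring each member's probability vector, which is the property the abstract credits with avoiding the tail-class gradient imbalance.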

Code Implementations: 1 repo
