Group Distributionally Robust Knowledge Distillation
This work addresses performance disparities in knowledge distillation when some data groups are underrepresented, as is common in medical imaging, and represents an incremental improvement over existing methods.
The paper tackles the vulnerability of knowledge distillation to sub-population shifts, such as underrepresented groups in medical imaging. It proposes a group-aware distillation loss that dynamically focuses on low-performing groups during training, which yields consistent improvements in worst-group accuracy on benchmark datasets of natural images and cardiac MRIs.
Knowledge distillation enables fast and effective transfer of features learned from a bigger model to a smaller one. However, distillation objectives are susceptible to sub-population shifts, a common scenario in medical imaging analysis in which some groups or domains of data are underrepresented in the training set. For instance, models trained on health data acquired from multiple scanners or hospitals can perform poorly on minority groups. In this paper, inspired by distributionally robust optimization (DRO) techniques, we address this shortcoming by proposing a group-aware distillation loss. During optimization, a set of weights is updated based on the per-group losses at the current iteration. This way, our method can dynamically focus on groups that have low performance during training. We empirically validate our method, GroupDistil, on two benchmark datasets (natural images and cardiac MRIs) and show consistent improvement in terms of worst-group accuracy.
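The idea of reweighting groups by their current losses can be illustrated with a small NumPy sketch. It is not the authors' implementation: it assumes a Hinton-style temperature-scaled KL distillation term and the exponentiated-gradient weight update from Group DRO (Sagawa et al., 2020), and the names `per_sample_kd_loss`, `GroupWeights`, `step_size`, and `temperature` are hypothetical. The paper's exact loss and update rule may differ.

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Row-wise log-softmax of temperature-scaled logits, shape (n, k)."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def per_sample_kd_loss(student_logits, teacher_logits, temperature=2.0):
    """Assumed Hinton-style distillation loss: KL(teacher || student) on
    softened distributions, scaled by temperature**2, returned per sample."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    p_t = np.exp(log_p_t)
    return (p_t * (log_p_t - log_p_s)).sum(axis=1) * temperature**2


class GroupWeights:
    """Hypothetical helper: one weight per group, kept on the probability
    simplex and updated by exponentiated gradient ascent (Group DRO style)."""

    def __init__(self, n_groups, step_size=0.01):
        self.q = np.full(n_groups, 1.0 / n_groups)  # start uniform
        self.step_size = step_size

    def group_losses(self, losses, groups):
        """Mean loss per group; groups absent from the batch get 0."""
        n = len(self.q)
        sums = np.bincount(groups, weights=losses, minlength=n)
        counts = np.bincount(groups, minlength=n)
        return np.divide(sums, counts, out=np.zeros(n), where=counts > 0)

    def step(self, losses, groups):
        """Upweight high-loss groups, renormalize, return the robust loss."""
        g_loss = self.group_losses(losses, groups)
        self.q = self.q * np.exp(self.step_size * g_loss)
        self.q = self.q / self.q.sum()
        # Weighted sum of group losses; in an autodiff framework the weights
        # would be treated as constants and gradients flow through g_loss.
        return float(self.q @ g_loss), g_loss


# Toy usage: 8 samples, 3 classes, group 1 is the minority (2 samples).
rng = np.random.default_rng(0)
teacher = rng.normal(size=(8, 3))
student = rng.normal(size=(8, 3))
groups = np.array([0, 0, 0, 0, 0, 0, 1, 1])
gw = GroupWeights(n_groups=2, step_size=0.5)
for _ in range(5):  # student is frozen here, so per-group losses stay fixed
    robust_loss, g_loss = gw.step(per_sample_kd_loss(student, teacher), groups)
print("group weights:", gw.q, "group losses:", g_loss)
```

Because the weights grow multiplicatively with each group's loss, the group the student currently distills worst receives the largest weight, regardless of how many samples it contributes. Averaging per group rather than per sample is what keeps the two-sample minority group from being drowned out.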