Sinkhorn Distance Minimization for Knowledge Distillation
This work improves knowledge distillation (KD) for compressing large language models, a capability that matters for deploying efficient AI in resource-constrained environments. The contribution appears incremental, however, since it builds on existing KD frameworks by introducing a new divergence measure.
The paper tackles knowledge distillation (KD) for compressing large language models (LLMs). Existing divergence measures such as KL, RKL, and JS fail when the teacher and student distributions have little overlap, so the authors propose Sinkhorn Knowledge Distillation (SinKD). SinKD combines the Sinkhorn distance with a batch-wise reformulation and outperforms state-of-the-art methods on the GLUE and SuperGLUE benchmarks across various LLM architectures.
Knowledge distillation (KD) has been widely adopted to compress large language models (LLMs). Existing KD methods investigate various divergence measures, including the Kullback-Leibler (KL), reverse Kullback-Leibler (RKL), and Jensen-Shannon (JS) divergences. However, due to limitations inherent in their assumptions and definitions, these measures fail to deliver effective supervision when there is little overlap between the teacher and student distributions. In this paper, we show that the KL, RKL, and JS divergences suffer from mode-averaging, mode-collapsing, and mode-underestimation, respectively, which degrades logits-based KD across diverse NLP tasks. We propose Sinkhorn Knowledge Distillation (SinKD), which exploits the Sinkhorn distance to provide a nuanced and precise assessment of the disparity between teacher and student distributions. Moreover, owing to properties of the Sinkhorn metric, we can dispense with sample-wise KD, which restricts the perception of divergence to each individual teacher-student sample pair. Instead, we propose a batch-wise reformulation that captures the geometric intricacies of distributions across samples in the high-dimensional space. Comprehensive evaluation on GLUE and SuperGLUE, in terms of comparability, validity, and generalizability, demonstrates SinKD's superiority over state-of-the-art methods on LLMs of all three architecture types: encoder-only, encoder-decoder, and decoder-only.
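To make the Sinkhorn distance and the idea of a batch-wise loss concrete, here is a minimal pure-Python sketch of the standard entropy-regularized optimal-transport computation (alternating Sinkhorn scaling of a Gibbs kernel), followed by a hypothetical batch-level KD loss built on it. The function names `sinkhorn_distance` and `batch_sinkhorn_kd_loss`, the L1 ground cost between per-sample probability vectors, the uniform 1/N mass per sample, and the values of `temperature`, `eps`, and `n_iters` are all illustrative assumptions. They are not the paper's exact formulation or hyperparameters.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax; `temperature` is the usual KD softening factor."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sinkhorn_distance(a, b, cost, eps=0.1, n_iters=200):
    """Entropy-regularized optimal-transport cost between histograms a and b.

    a: n nonnegative weights summing to 1 (row marginal of the plan).
    b: m nonnegative weights summing to 1 (column marginal of the plan).
    cost: n x m ground-cost matrix as nested lists.
    Returns (distance, plan) with distance = sum_ij plan[i][j] * cost[i][j].
    """
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-cost / eps); a smaller eps approaches exact OT.
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows to match a, then columns to match b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Transport plan P = diag(u) K diag(v) and its total transport cost.
    plan = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    distance = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return distance, plan


def batch_sinkhorn_kd_loss(teacher_logits, student_logits, temperature=1.0, eps=0.1):
    """HYPOTHETICAL batch-wise KD loss (an assumption, not the paper's definition).

    Each sample's predicted distribution is treated as one point; the ground cost
    between teacher sample i and student sample j is the L1 distance between their
    probability vectors, and every sample carries uniform mass 1/N.
    """
    t_probs = [softmax(row, temperature) for row in teacher_logits]
    s_probs = [softmax(row, temperature) for row in student_logits]
    cost = [[sum(abs(pt - ps) for pt, ps in zip(t, s)) for s in s_probs] for t in t_probs]
    n_t, n_s = len(t_probs), len(s_probs)
    distance, _ = sinkhorn_distance([1.0 / n_t] * n_t, [1.0 / n_s] * n_s, cost, eps)
    return distance


teacher = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
print(batch_sinkhorn_kd_loss(teacher, teacher))          # close to 0: identical predictions
print(batch_sinkhorn_kd_loss(teacher, [[0.0] * 3] * 3))  # about 1.152: uniform student
```

Because the cost is computed between every teacher sample and every student sample, the transport plan in this sketch can relate predictions across the batch rather than only within matched pairs. Unlike KL, the Sinkhorn cost also stays finite and informative when the two distributions barely overlap. In a real training loop one would express the same iterations with tensor operations (often in the log domain for numerical stability) so the loss can be backpropagated.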