Learnable Sampler Distillation for Discrete Diffusion Models
This work addresses a critical bottleneck for practical applications of DDMs in domains like text and molecule generation: it accelerates sampling while maintaining quality, and it introduces a new method rather than an incremental improvement.
The paper tackles the problem of inefficient sampling in discrete diffusion models (DDMs) by proposing learnable sampler distillation (LSD) and LSD+, which train fast, high-fidelity samplers. It reports substantially higher sampling quality with significantly fewer steps across text generation, image generation, and synthetic tasks.
Discrete diffusion models (DDMs) have shown powerful generation ability for discrete data modalities like text and molecules. However, their practical application is hindered by inefficient sampling, which requires a large number of sampling steps. Accelerating DDMs by using larger step sizes typically degrades generation quality, as it amplifies both the compounding decoding error due to factorized predictions and the discretization error from numerical approximations. To address these challenges, we propose learnable sampler distillation (LSD), a novel approach to train fast and high-fidelity samplers for DDMs. LSD employs a distillation approach in which a student sampler with a few steps learns to align its intermediate score trajectory with that of a high-quality teacher sampler with numerous steps. This alignment is achieved by optimizing learnable sampler coefficients that adaptively adjust sampling dynamics. We further propose LSD+, which also learns time schedules that allocate steps non-uniformly. Experiments across text generation, image generation, and synthetic tasks demonstrate that our proposed approaches outperform existing samplers for DDMs, achieving substantially higher sampling quality with significantly fewer sampling steps. Our code is available at \href{https://github.com/feiyangfu/LSD}{https://github.com/feiyangfu/LSD}.
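To make the distillation idea concrete, the following is a minimal toy sketch, not the authors' implementation. It assumes a two-state continuous-time Markov chain in place of a neural DDM, tracks a single probability instead of a score trajectory, and fits each coefficient in closed form rather than by gradient-based training. All names (`drift`, `rollout`, `distill`) and rate values are hypothetical. A 1000-step Euler teacher supplies intermediate targets, and a 4-step student learns one multiplicative coefficient per step so that its intermediate states match the teacher's.

```python
def drift(p, a=2.0, b=1.0):
    """Time derivative of P(state 0) for a two-state chain.

    Assumed toy rates: a for 0 -> 1 and b for 1 -> 0.
    """
    return b - (a + b) * p


def rollout(p, horizon, coeffs):
    """Euler-style sampler with uniform steps; coeffs[k] rescales step k."""
    h = horizon / len(coeffs)
    traj = [p]
    for c in coeffs:
        p = p + c * h * drift(p)
        traj.append(p)
    return traj


def distill(p_init, horizon, teacher_steps, student_steps):
    """Fit one coefficient per student step against the teacher trajectory."""
    teacher = rollout(p_init, horizon, [1.0] * teacher_steps)
    stride = teacher_steps // student_steps
    h = horizon / student_steps
    coeffs, p = [], p_init
    for k in range(student_steps):
        target = teacher[stride * (k + 1)]
        # In one dimension the squared-error minimiser has a closed form.
        c = (target - p) / (h * drift(p))
        coeffs.append(c)
        p = p + c * h * drift(p)
    return coeffs, teacher


coeffs, teacher = distill(1.0, 2.0, teacher_steps=1000, student_steps=4)
targets = teacher[::250]                      # teacher states on the student grid
plain = rollout(1.0, 2.0, [1.0] * 4)          # 4-step sampler, no learned coefficients
student = rollout(1.0, 2.0, coeffs)           # 4-step sampler, learned coefficients
err_plain = max(abs(x - y) for x, y in zip(plain, targets))
err_student = max(abs(x - y) for x, y in zip(student, targets))
```

In this toy setting the learned coefficients shrink each oversized step, so the student tracks the teacher at every intermediate point while the plain 4-step sampler overshoots. The sketch keeps the time grid uniform; learning a non-uniform schedule, as LSD+ does, is not covered here.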