Contrastive Conditional-Unconditional Alignment for Long-tailed Diffusion Model
This work addresses imbalanced training data for class-conditional diffusion models, a domain-specific problem in image generation. It is incremental in that it builds on existing diffusion models and adds novel loss functions.
The paper tackles class-conditional image synthesis from long-tailed data, where scarce tail-class images cause mode collapse and reduced diversity. It introduces two loss functions that improve diversity and fidelity for tail classes without harming head-class quality, and reports superior performance on datasets such as ImageNet-LT at 256x256 resolution.
Training data for class-conditional image synthesis often exhibit a long-tailed distribution, with few images for tail classes. This imbalance causes mode collapse and reduces the diversity of synthesized images for tail classes. For class-conditional diffusion models trained on imbalanced data, we aim to improve the diversity and fidelity of tail-class images without compromising the quality of head-class images. We achieve this by introducing two simple but highly effective loss functions. First, we employ an Unsupervised Contrastive Loss (UCL) that uses negative samples to increase the distance, or dissimilarity, among synthetic images. This regularization is combined with the standard trick of batch resampling to further diversify tail-class images. Second, an Alignment Loss (AL) aligns class-conditional generation with unconditional generation at large timesteps. This makes the denoising process insensitive to class conditions during the initial (high-noise) steps, which enriches tail classes through knowledge sharing from head classes. We thus successfully leverage contrastive learning and conditional-unconditional alignment for class-imbalanced diffusion models. Our framework is easy to implement, as demonstrated on both a U-Net-based architecture and a Diffusion Transformer. Our method outperforms vanilla denoising diffusion probabilistic models, score-based diffusion models, and alternative methods for class-imbalanced image generation across various datasets, in particular ImageNet-LT at 256x256 resolution.
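To make the two losses and the resampling trick more concrete, here is a minimal, framework-free Python sketch. It is not the authors' implementation, and several details are assumptions not stated in the abstract. The contrastive term is written in a standard InfoNCE form in which each anchor's positive is a second "view" of the same synthetic sample and the rest of the batch serves as negatives. The alignment term is taken as a mean squared error between conditional and unconditional noise predictions. The threshold `t_align`, the temperature `tau`, and all function names are hypothetical. The class-balanced weights show one common form of batch resampling.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two non-zero feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def contrastive_loss(anchors: List[Vector], positives: List[Vector], tau: float = 0.1) -> float:
    """InfoNCE-style contrastive loss (assumed form, not necessarily the paper's UCL).

    anchors[i] should be most similar to positives[i]; every positives[j]
    with j != i acts as a negative sample, so minimizing the loss pushes
    features of different synthetic images apart.
    """
    n = len(anchors)
    total = 0.0
    for i in range(n):
        logits = [cosine(anchors[i], positives[j]) / tau for j in range(n)]
        m = max(logits)  # subtract the max before exp for numerical stability
        log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denominator - logits[i]  # -log softmax at the positive index
    return total / n


def alignment_loss(
    eps_cond: List[Vector],
    eps_uncond: List[Vector],
    timesteps: List[int],
    t_align: int,
) -> float:
    """MSE between conditional and unconditional noise predictions.

    Only samples with a large timestep (t >= t_align, a hypothetical
    threshold) contribute, so the class condition is encouraged to have
    little effect in the early, high-noise denoising steps.
    """
    terms = []
    for ec, eu, t in zip(eps_cond, eps_uncond, timesteps):
        if t >= t_align:
            terms.append(sum((a - b) ** 2 for a, b in zip(ec, eu)) / len(ec))
    return sum(terms) / len(terms) if terms else 0.0


def class_balanced_weights(labels: Sequence[int]) -> List[float]:
    """Per-sample sampling weights proportional to 1 / (class frequency).

    Drawing batches with these weights makes every class equally likely,
    so tail-class images are seen more often during training.
    """
    counts: Dict[int, int] = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return [1.0 / counts[y] for y in labels]
```

In an actual training loop these terms would presumably be added to the standard denoising loss with weighting coefficients (also assumptions here), and tensors from a deep-learning framework would replace the Python lists. For example, `contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])` is close to zero because each anchor matches its positive, whereas swapping the positives makes the loss large.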