Variational Schrödinger Diffusion Models
This work targets researchers and practitioners facing computational bottlenecks in diffusion models. It offers a more scalable and tuning-friendly approach, although the contribution is incremental because it builds on existing Schrödinger bridge methods.
The paper tackles the scalability and training-cost issues of Schrödinger bridge diffusion models by using variational inference to linearize the forward score functions. The resulting Variational Schrödinger Diffusion Model (VSDM) generates anisotropic shapes efficiently and achieves competitive performance on CIFAR10 and on time-series modeling.
The Schrödinger bridge (SB) has emerged as the go-to method for optimizing transportation plans in diffusion models. However, SB requires estimating intractable forward score functions, which inevitably leads to a costly implicit training loss based on simulated trajectories. To improve scalability while preserving efficient transportation plans, we leverage variational inference to linearize the forward score functions (variational scores) of SB and restore simulation-free training of the backward scores. We propose the variational Schrödinger diffusion model (VSDM), in which the forward process is a multivariate diffusion and the variational scores are adaptively optimized for efficient transport. Theoretically, we use stochastic approximation to prove the convergence of the variational scores and show the convergence of the adaptively generated samples based on the optimal variational scores. Empirically, we test the algorithm on simulated examples and observe that VSDM generates anisotropic shapes efficiently and yields straighter sample trajectories than single-variate diffusion. We also verify the scalability of the algorithm on real-world data, achieving competitive unconditional generation performance on CIFAR10 and conditional generation in time-series modeling. Notably, VSDM no longer depends on warm-up initializations and is tuning-friendly when training large-scale experiments.
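Why does linearizing the forward scores restore simulation-free training? A linear forward SDE has Gaussian transition kernels with closed-form mean and covariance. This means x_t can be sampled directly from x_0, and the denoising score-matching target is available without simulating trajectories. The sketch below illustrates that mechanism only. It assumes a time-invariant linear SDE dX = -0.5 * beta * D X dt + sqrt(beta) dW, with constant beta and a fixed symmetric positive-definite matrix D. Both are hypothetical stand-ins for the paper's adaptively optimized variational scores, not VSDM's actual parameterization. The helper names `forward_marginal` and `sample_and_score_target` are likewise hypothetical.

```python
import numpy as np


def forward_marginal(D, beta, t):
    """Closed-form transition of dX = -0.5*beta*D X dt + sqrt(beta) dW from a fixed x0.

    Assumption (illustrative only): D is symmetric positive definite, beta is constant.
    Returns M_t (so the mean is M_t @ x0) and the covariance Sigma_t.
    """
    lam, U = np.linalg.eigh(D)  # D = U diag(lam) U^T
    # Mean matrix: M_t = exp(-0.5 * beta * D * t)
    M = U @ np.diag(np.exp(-0.5 * beta * lam * t)) @ U.T
    # Covariance solves dSigma/dt = -0.5*beta*(D Sigma + Sigma D) + beta*I, Sigma_0 = 0
    Sigma = U @ np.diag((1.0 - np.exp(-beta * lam * t)) / lam) @ U.T
    return M, Sigma


def sample_and_score_target(x0, D, beta, t, rng):
    """Simulation-free: draw x_t ~ N(M_t x0, Sigma_t) in one step and return the
    conditional score grad log p(x_t | x0) = -Sigma_t^{-1} (x_t - M_t x0),
    which serves as the denoising score-matching target for a backward-score network."""
    M, Sigma = forward_marginal(D, beta, t)
    L = np.linalg.cholesky(Sigma)  # Sigma = L L^T, valid for t > 0
    eps = rng.standard_normal(x0.shape[-1])
    xt = M @ x0 + L @ eps
    score = -np.linalg.solve(Sigma, xt - M @ x0)
    return xt, score


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Anisotropic (non-diagonal) D: noise is injected at different rates per direction.
    D = np.array([[2.0, 0.5], [0.5, 1.0]])
    x0 = np.array([1.0, -1.0])
    xt, score = sample_and_score_target(x0, D, beta=1.0, t=0.7, rng=rng)
    print(xt, score)
```

With D equal to the identity, this reduces to the familiar single-variate (VP-style) case, where Sigma_t = (1 - exp(-beta t)) I. A non-diagonal D is what makes the diffusion multivariate. In the sketch, D is fixed, whereas the abstract describes the variational scores as being optimized adaptively during training.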