Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel
This work gives faster, parallelizable sampling algorithms for diffusion models, which matters for practical generative AI applications; the contribution is incremental in that it builds on existing methods.
The paper addresses slow sampling in diffusion models with a new scheme based on randomized midpoints. The scheme achieves a dimension dependence of Õ(d^{5/12}) in total variation distance, improving on the prior Õ(√d) bound, and it can be parallelized to run in Õ(log² d) parallel rounds.
Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works~\cite{chen2023sampling,chen2023ode,benton2023error,lee2022convergence} have proposed schemes for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee's randomized midpoint method for log-concave sampling~\cite{ShenL19}. We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde O(d^{5/12})$ compared to $\widetilde O(\sqrt{d})$ from prior work). We also show that our algorithm can be parallelized to run in only $\widetilde O(\log^2 d)$ parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models. As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving dimension dependence $\widetilde O(d^{5/12})$ compared to $\widetilde O(\sqrt{d})$ from prior work.
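To make the randomized midpoint idea concrete, here is a minimal sketch of one way it can be applied to a generic ODE $\dot{x} = f(x, t)$. It is not the paper's sampler. It assumes a scalar state and a toy drift, and the names `randomized_midpoint_step`, `integrate`, and `toy_drift` are hypothetical. The core move is to evaluate the drift at a uniformly random point inside each step instead of at a fixed midpoint, so that the drift-integral estimate is unbiased along the approximate trajectory. In a diffusion sampler the drift would instead come from a learned score estimate, which this sketch does not model.

```python
import math
import random


def randomized_midpoint_step(f, x, t, h, rng):
    """One randomized-midpoint step for dx/dt = f(x, t) (illustrative only)."""
    # Draw a uniformly random fraction of the step.
    alpha = rng.random()
    # Euler predictor to the random intermediate time t + alpha * h.
    x_mid = x + alpha * h * f(x, t)
    # Full step using the drift evaluated at the random intermediate point.
    return x + h * f(x_mid, t + alpha * h)


def integrate(f, x0, t0, t1, n_steps, seed=0):
    """Integrate from t0 to t1 with n_steps randomized-midpoint steps."""
    rng = random.Random(seed)  # seeded so the run is reproducible
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = randomized_midpoint_step(f, x, t, h, rng)
        t += h
    return x


def toy_drift(x, t):
    # Assumed toy drift (not from the paper): dx/dt = -x, exact solution x0 * exp(-t).
    return -x


approx = integrate(toy_drift, 1.0, 0.0, 1.0, n_steps=100)
exact = math.exp(-1.0)
print(f"randomized midpoint: {approx:.6f}, exact: {exact:.6f}")
```

On this toy problem the result should land close to $e^{-1}$. The example illustrates only the sequential step. The parallel variant described above is not sketched here.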