LG · AI · ML · Jun 25, 2025

Diffusion Tree Sampling: Scalable inference-time alignment of diffusion models

arXiv:2506.20701v1 · 27 citations · h-index: 24
Originality: Highly original
AI Analysis

The paper provides a scalable method for inference-time alignment of diffusion models, addressing the inefficiencies of existing steering approaches in generative modeling.

The paper tackles the problem of adapting pretrained diffusion models to new objectives at inference time by introducing Diffusion Tree Sampling (DTS), which reuses past computations to improve sample quality and efficiency. On MNIST and CIFAR-10 class-conditional generation, DTS matches the FID of the best-performing baseline with up to 10x less compute. On text-to-image and language completion tasks, its greedy variant, Diffusion Tree Search (DTS*), finds high-reward samples that match best-of-N with up to 5x less compute.

Adapting a pretrained diffusion model to new objectives at inference time remains an open problem in generative modeling. Existing steering methods suffer from inaccurate value estimation, especially at high noise levels, which biases guidance. Moreover, information from past runs is not reused to improve sample quality, resulting in inefficient use of compute. Inspired by the success of Monte Carlo Tree Search, we address these limitations by casting inference-time alignment as a search problem that reuses past computations. We introduce a tree-based approach that samples from the reward-aligned target density by propagating terminal rewards back through the diffusion chain and iteratively refining value estimates with each additional generation. Our proposed method, Diffusion Tree Sampling (DTS), produces asymptotically exact samples from the target distribution in the limit of infinite rollouts, and its greedy variant, Diffusion Tree Search (DTS$^\star$), performs a global search for high reward samples. On MNIST and CIFAR-10 class-conditional generation, DTS matches the FID of the best-performing baseline with up to $10\times$ less compute. In text-to-image generation and language completion tasks, DTS$^\star$ effectively searches for high reward samples that match best-of-N with up to $5\times$ less compute. By reusing information from previous generations, we get an anytime algorithm that turns additional compute into steadily better samples, providing a scalable approach for inference-time alignment of diffusion models.
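The abstract describes the mechanics only at a high level: terminal rewards are propagated back up a tree of partial denoising trajectories, value estimates are refined with every new rollout, and samples are drawn by descending the tree. The Python sketch below is a toy illustration of that idea under stated assumptions, not the authors' implementation. Assumed (not from the source): a random-walk "denoiser" standing in for the pretrained reverse chain, the hypothetical reward `reward(x0) = x0`, an exponentially tilted target of the form p(x0) exp(λ r(x0)) with temperature `lam`, a branching limit `k`, Boltzmann selection of children with probability proportional to exp(λ v), a soft Bellman backup v = (1/λ) log mean exp(λ v_child), and one new node added per iteration. All names (`Node`, `search`, `sample`, `denoise_step`) are hypothetical.

```python
from __future__ import annotations

import math
import random
from dataclasses import dataclass, field

# ASSUMPTION: toy stand-in for a pretrained reverse diffusion chain. The state is
# an integer, x_T = 0, and each reverse step adds -1 or +1 with equal probability.
T = 6

def denoise_step(x: int, rng: random.Random) -> int:
    return x + rng.choice((-1, 1))

def reward(x0: int) -> float:
    # Hypothetical reward: prefer larger terminal states.
    return float(x0)

def log_mean_exp(values: list[float]) -> float:
    # Numerically stable log(mean(exp(values))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

@dataclass
class Node:
    x: int
    t: int
    value: float = 0.0
    children: list[Node] = field(default_factory=list)

def rollout(x: int, t: int, rng: random.Random) -> int:
    # Finish the chain with the base model from (x, t) down to t = 0.
    while t > 0:
        x, t = denoise_step(x, rng), t - 1
    return x

def pick_child(node: Node, lam: float, greedy: bool, rng: random.Random) -> Node:
    if greedy:  # DTS*-style: follow the highest value estimate
        return max(node.children, key=lambda c: c.value)
    # DTS-style: Boltzmann selection, P(child) proportional to exp(lam * value)
    m = max(c.value for c in node.children)
    weights = [math.exp(lam * (c.value - m)) for c in node.children]
    return rng.choices(node.children, weights=weights)[0]

def search(root: Node, iters: int, lam: float, k: int, greedy: bool,
           rng: random.Random) -> float:
    """Grow the tree for `iters` iterations; return the best reward observed."""
    best = -math.inf
    for _ in range(iters):
        path, node = [root], root
        # Selection: descend through nodes that already have k children.
        while node.t > 0 and len(node.children) >= k:
            node = pick_child(node, lam, greedy, rng)
            path.append(node)
        if node.t > 0:
            # Expansion: add one child from the base model; its initial value
            # is the reward of a single base-model rollout from that child.
            child = Node(denoise_step(node.x, rng), node.t - 1)
            child.value = reward(rollout(child.x, child.t, rng))
            node.children.append(child)
            path.append(child)
        best = max(best, path[-1].value)
        # Backup: soft Bellman update of every ancestor on the path.
        for n in reversed(path[:-1]):
            n.value = log_mean_exp([lam * c.value for c in n.children]) / lam
    return best

def sample(root: Node, lam: float, rng: random.Random) -> int:
    # Descend by Boltzmann selection; finish with the base model where the tree ends.
    node = root
    while node.t > 0 and node.children:
        node = pick_child(node, lam, False, rng)
    return rollout(node.x, node.t, rng)

rng = random.Random(0)
root = Node(x=0, t=T)
search(root, iters=2000, lam=1.0, k=4, greedy=False, rng=rng)
samples = [sample(root, 1.0, rng) for _ in range(300)]
print("DTS-style sample mean:", sum(samples) / len(samples))  # base chain mean is 0

best = search(Node(x=0, t=T), iters=500, lam=1.0, k=4, greedy=True, rng=rng)
print("DTS*-style best reward:", best)
```

Because the tree persists across iterations, each extra call to `search` refines the same value estimates instead of starting over, which is the "anytime" property the abstract emphasizes. Sampling only among expanded children is a simplification in this toy; it ignores probability mass the tree has not explored yet.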

