Fine-Tuning Masked Diffusion for Provable Self-Correction
This work addresses the challenge of improving inference quality in generative models for applications such as text and code generation, though it is incremental, building on existing MDM frameworks.
The paper tackles self-correction in Masked Diffusion Models (MDMs) by introducing PRISM, a lightweight, model-agnostic method that learns per-token quality scores without architectural changes, reinforcement learning, or a verifier. PRISM improves results on Sudoku, unconditional text generation with a 170M-parameter model, and code generation with LLaDA (8B).
A natural desideratum for generative models is self-correction: detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures or training, or rely on imprecise proxies for token quality, which limits their applicability. Motivated by this, we introduce PRISM (Plug-in Remasking for Inference-time Self-correction of Masked Diffusions), a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores without RL or a verifier. These quality scores are computed in the same forward pass as the MDM and used to detect low-quality tokens. Empirically, PRISM advances MDM inference across domains and scales: Sudoku, unconditional text (170M), and code with LLaDA (8B).
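To make the inference-time use of per-token quality scores concrete, here is a minimal sketch of one possible remasking step. It assumes the model has already produced a quality score for every position (higher meaning more likely correct) and remasks a fixed number `k` of the lowest-scoring decoded tokens so that later denoising steps can re-predict them. The function name `remask_lowest_quality`, the mask id `-1`, and the fixed-`k` selection rule are hypothetical illustrations, not PRISM's actual interface or schedule; the abstract does not say how many tokens are remasked per step or how the scores are thresholded.

```python
from typing import List, Sequence

MASK = -1  # hypothetical mask-token id, chosen only for illustration


def remask_lowest_quality(
    tokens: Sequence[int],
    quality: Sequence[float],
    k: int,
    mask_id: int = MASK,
) -> List[int]:
    """Return a copy of `tokens` with the k lowest-quality decoded positions remasked.

    Assumption (not from the paper): a fixed k per step, with higher scores
    meaning a token is more likely to be correct.
    """
    if len(tokens) != len(quality):
        raise ValueError("tokens and quality must have the same length")
    if k < 0:
        raise ValueError("k must be non-negative")
    # Only already-decoded (unmasked) positions are candidates for revision.
    candidates = [i for i, t in enumerate(tokens) if t != mask_id]
    # Lowest quality first; ties are broken by position for determinism.
    candidates.sort(key=lambda i: (quality[i], i))
    out = list(tokens)
    for i in candidates[:k]:
        out[i] = mask_id  # send the suspect token back to the mask state
    return out


# Usage: positions 1 and 3 have the lowest scores, so they are remasked.
print(remask_lowest_quality([5, 9, 2, 7], [0.9, 0.1, 0.8, 0.3], k=2))
# prints [5, -1, 2, -1]
```

In a full sampler, the remasked positions would be filled again by the MDM on subsequent steps; a score threshold instead of a fixed `k` is an equally plausible design and would only change the selection line.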