Consistent Sampling and Simulation: Molecular Dynamics with Energy-Based Diffusion Models
This work addresses a fundamental inconsistency in diffusion models for molecular simulation, enabling more reliable sampling and simulation of biomolecules such as proteins and dipeptides. The contribution is incremental in that it builds on existing diffusion methods.
The paper tackles the inconsistency between diffusion-model sampling and the energy-based interpretation of the learned score in molecular dynamics. It traces the problem to score inaccuracies at small diffusion timesteps and proposes a Fokker-Planck-derived regularization term to enforce consistency. The results include a state-of-the-art transferable Boltzmann emulator for dipeptides that achieves improved consistency and efficient sampling.
In recent years, diffusion models trained on equilibrium molecular distributions have proven effective for sampling biomolecules. Beyond direct sampling, the score of such a model can also be used to derive the forces that act on molecular systems. However, while classical diffusion sampling usually recovers the training distribution, the corresponding energy-based interpretation of the learned score is often inconsistent with this distribution, even for low-dimensional toy systems. We trace this inconsistency to inaccuracies of the learned score at very small diffusion timesteps, where the model must capture the correct evolution of the data distribution. In this regime, diffusion models fail to satisfy the Fokker-Planck equation, which governs the evolution of the score. We interpret this deviation as one source of the observed inconsistencies and propose an energy-based diffusion model with a Fokker-Planck-derived regularization term to enforce consistency. We demonstrate our approach by sampling and simulating multiple biomolecular systems, including fast-folding proteins, and by introducing a state-of-the-art transferable Boltzmann emulator for dipeptides that supports simulation and achieves improved consistency and efficient sampling. Our code, model weights, and self-contained JAX and PyTorch notebooks are available at https://github.com/noegroup/ScoreMD.
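To make the abstract's two central ideas concrete, the following JAX sketch shows (i) a score obtained as the negative gradient of an energy, which at diffusion time zero coincides with the force in reduced units (kT = 1), and (ii) a Fokker-Planck residual that can serve as a regularization penalty. It is a toy illustration, not the authors' implementation. It assumes a deliberately simple noising process, dx = sqrt(2) dW, for which the log-density obeys d/dt log p = d2/dx2 log p + (d/dx log p)^2; the paper's actual diffusion process and loss may differ. An analytically known 1D Gaussian stands in for a trained network, and every name (`log_p_exact`, `log_p_perturbed`, `fp_residual`, `fp_penalty`, `DATA_VAR`) is hypothetical.

```python
import jax
import jax.numpy as jnp

DATA_VAR = 0.5  # variance of the toy 1D Gaussian "data" distribution (assumed)


def log_p_exact(x, t):
    """Exact log-density of the noised data under the assumed dx = sqrt(2) dW."""
    var = DATA_VAR + 2.0 * t  # this noising process adds variance 2t
    return -0.5 * x**2 / var - 0.5 * jnp.log(2.0 * jnp.pi * var)


def log_p_perturbed(x, t):
    """Hypothetical imperfect model: the exact log-density plus a spurious ripple."""
    return log_p_exact(x, t) + 0.1 * jnp.sin(3.0 * x) * jnp.exp(-5.0 * t)


def score(log_p, x, t):
    """Energy-based score: s = -dU/dx with U = -log p, conservative by construction."""
    energy = lambda x_, t_: -log_p(x_, t_)
    return -jax.grad(energy, argnums=0)(x, t)


def fp_residual(log_p, x, t):
    """Residual of d/dt log p = d2/dx2 log p + (d/dx log p)^2 at one point."""
    d_t = jax.grad(log_p, argnums=1)(x, t)
    d_x = jax.grad(log_p, argnums=0)
    d_xx = jax.grad(d_x, argnums=0)(x, t)
    return d_t - (d_xx + d_x(x, t) ** 2)


def fp_penalty(log_p, xs, ts):
    """Mean squared residual over a batch; a candidate regularization term."""
    residuals = jax.vmap(lambda x, t: fp_residual(log_p, x, t))(xs, ts)
    return jnp.mean(residuals**2)


xs = jnp.linspace(-2.0, 2.0, 9)
ts = jnp.full_like(xs, 0.01)  # a small diffusion time

penalty_exact = fp_penalty(log_p_exact, xs, ts)          # vanishes up to round-off
penalty_perturbed = fp_penalty(log_p_perturbed, xs, ts)  # clearly non-zero
force_at_one = score(log_p_exact, 1.0, 0.0)              # equals -x / DATA_VAR at t = 0
```

In this toy setting the penalty is zero, up to floating-point error, for the exact density and clearly positive for the perturbed one, which is the kind of signal a Fokker-Planck-derived regularizer can penalize during training. Evaluating the score at t = 0 gives the force that would drive a simulation.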