JAX, M.D.: A Framework for Differentiable Physics

arXiv:1912.04232v2 · 40 citations · Has Code
AI Analysis

JAX MD gives researchers in computational physics and machine learning a tool for integrating neural networks into simulations and optimizing through them. Because the framework builds on existing differentiable-programming concepts, however, its contribution is incremental.

The authors introduced JAX MD, a software framework for differentiable physics simulations focused on molecular dynamics. It supports meta-optimization by differentiating through entire trajectories and scales to hundreds of thousands of particles on a single GPU.
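
Trajectory differentiation can be illustrated in plain JAX, without JAX MD itself. The following is a minimal sketch under stated assumptions, not the JAX MD API: it uses a hypothetical Gaussian pair repulsion of width `sigma`, a velocity-Verlet integrator with unit masses written as `jax.lax.scan`, and a toy meta-objective (how far the particles have spread after 100 steps). Because every step is a differentiable function, `jax.grad` returns d(objective)/d(`sigma`) through the entire trajectory. The names `gaussian_energy`, `run_trajectory`, `spread`, and `meta_loss` are illustrative only.

```python
import jax
import jax.numpy as jnp

def gaussian_energy(R, sigma, epsilon=1.0):
    """Total energy of an assumed smooth Gaussian repulsion between all particle pairs."""
    N = R.shape[0]
    dR = R[:, None, :] - R[None, :, :]            # (N, N, dim) pair displacements
    dr2 = jnp.sum(dR ** 2, axis=-1)               # (N, N) squared distances
    off_diag = ~jnp.eye(N, dtype=bool)            # exclude self-interaction
    e = jnp.where(off_diag, epsilon * jnp.exp(-0.5 * dr2 / sigma ** 2), 0.0)
    return 0.5 * jnp.sum(e)                       # each pair appears twice in the sum

def run_trajectory(R0, sigma, dt=0.01, n_steps=100):
    """Velocity-Verlet integration (unit masses) expressed as a differentiable scan."""
    force_fn = jax.grad(lambda R: -gaussian_energy(R, sigma))  # F = -dE/dR

    def step(state, _):
        R, V = state
        V_half = V + 0.5 * dt * force_fn(R)
        R_new = R + dt * V_half
        V_new = V_half + 0.5 * dt * force_fn(R_new)
        return (R_new, V_new), None

    (R_final, _), _ = jax.lax.scan(step, (R0, jnp.zeros_like(R0)), None, length=n_steps)
    return R_final

def spread(R):
    """Mean squared distance of the particles from their centroid."""
    return jnp.mean(jnp.sum((R - R.mean(axis=0)) ** 2, axis=-1))

# 3x3 grid of particles, spaced closer than the interaction width
xs = jnp.arange(3, dtype=jnp.float32) * 0.8
R0 = jnp.stack(jnp.meshgrid(xs, xs), axis=-1).reshape(-1, 2)

def meta_loss(sigma):
    """Toy meta-objective: final spread as a function of the potential's width."""
    return spread(run_trajectory(R0, sigma))

# Gradient of the meta-objective through all 100 integration steps
d_loss_d_sigma = jax.grad(meta_loss)(1.0)
```

In a meta-optimization loop this gradient would feed an ordinary optimizer, for example `sigma = sigma - 0.1 * d_loss_d_sigma`. Reverse-mode differentiation through `lax.scan` stores the intermediate states, so memory grows with `n_steps`.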

We introduce JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of physics simulation environments, as well as interaction potentials and neural networks that can be integrated into these environments without writing any additional code. Since the simulations themselves are differentiable functions, entire trajectories can be differentiated to perform meta-optimization. These features are built on primitive operations, such as spatial partitioning, that allow simulations to scale to hundreds of thousands of particles on a single GPU. These primitives are flexible enough that they can be used to scale up workloads outside of molecular dynamics. We present several examples that highlight the features of JAX MD, including the integration of graph neural networks into traditional simulations, meta-optimization through minimization of particle packings, and a multi-agent flocking simulation. JAX MD is available at www.github.com/google/jax-md.
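
Spatial partitioning is what keeps the per-step cost roughly linear in the number of particles. If space is cut into cells at least as wide as the interaction cutoff, each particle only has to look at its own cell and the 8 adjacent ones (in 2D) instead of at every other particle. The plain-JAX sketch below is a hypothetical, simplified cell list, not JAX MD's implementation. Particles are sorted into a fixed-capacity table so that array shapes stay static for JAX, an assumed soft-sphere pair energy is summed over the candidate neighbours, and the result can be checked against a brute-force all-pairs sum. It assumes a periodic square box, at least 3 cells per side, and that no cell holds more than `capacity` particles (JAX drops out-of-bounds scatter updates, so overflowing particles would silently vanish from the table).

```python
import jax
import jax.numpy as jnp

def displacement(Ra, Rb, box_size):
    """Minimum-image displacement Ra - Rb in a periodic square box."""
    dR = Ra - Rb
    return dR - box_size * jnp.round(dR / box_size)

def pair_energy(dr, sigma, epsilon=1.0):
    """Assumed soft-sphere repulsion: 0.5 * eps * (1 - r/sigma)^2 for r < sigma, else 0."""
    return jnp.where(dr < sigma, 0.5 * epsilon * (1.0 - dr / sigma) ** 2, 0.0)

def energy_all_pairs(R, box_size, sigma):
    """Reference O(N^2) energy: every particle is compared with every other one."""
    N = R.shape[0]
    dR = displacement(R[:, None, :], R[None, :, :], box_size)
    off_diag = ~jnp.eye(N, dtype=bool)
    # Replace self-distances by 1.0 before sqrt so gradients stay finite
    dr = jnp.sqrt(jnp.where(off_diag, jnp.sum(dR ** 2, axis=-1), 1.0))
    return 0.5 * jnp.sum(jnp.where(off_diag, pair_energy(dr, sigma), 0.0))

def build_cell_table(R, box_size, cells_per_side, capacity):
    """Sort particle indices into a (n_cells, capacity) table; the value N marks an empty slot."""
    N = R.shape[0]
    cell_xy = jnp.floor(R / box_size * cells_per_side).astype(jnp.int32) % cells_per_side
    cell_id = cell_xy[:, 0] * cells_per_side + cell_xy[:, 1]
    order = jnp.argsort(cell_id)                  # particle indices grouped by cell
    sorted_id = cell_id[order]
    # Slot inside a cell = position in the sorted list minus the cell's first position
    slot = jnp.arange(N) - jnp.searchsorted(sorted_id, sorted_id, side="left")
    table = jnp.full((cells_per_side ** 2, capacity), N, dtype=jnp.int32)
    table = table.at[sorted_id, slot].set(order.astype(jnp.int32))
    return table, cell_xy

def energy_cell_list(R, box_size, sigma, capacity=16):
    """Same energy, but each particle only visits the 3x3 block of cells around it."""
    N = R.shape[0]
    cells_per_side = int(box_size // sigma)       # cell width >= cutoff sigma
    assert cells_per_side >= 3, "need >= 3 cells per side so neighbour cells are distinct"
    table, cell_xy = build_cell_table(R, box_size, cells_per_side, capacity)
    offsets = jnp.array([(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)])
    nbr_xy = (cell_xy[:, None, :] + offsets[None, :, :]) % cells_per_side  # (N, 9, 2)
    nbr_cell = nbr_xy[..., 0] * cells_per_side + nbr_xy[..., 1]            # (N, 9)
    candidates = table[nbr_cell].reshape(N, -1)                             # (N, 9 * capacity)
    valid = (candidates < N) & (candidates != jnp.arange(N)[:, None])
    # Extra row N is a dummy position gathered for empty slots (masked out by `valid`)
    R_pad = jnp.concatenate([R, jnp.zeros((1, R.shape[1]), R.dtype)])
    dR = displacement(R[:, None, :], R_pad[candidates], box_size)
    dr = jnp.sqrt(jnp.where(valid, jnp.sum(dR ** 2, axis=-1), 1.0))
    return 0.5 * jnp.sum(jnp.where(valid, pair_energy(dr, sigma), 0.0))
```

Forces come from the same function via `jax.grad` with respect to `R`: the integer cell bookkeeping carries no gradient, and only the gathered positions are differentiated. A production neighbour list would also rebuild the table only every few steps rather than on each energy call.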

Code Implementations: 1 repo