JAX-MPM: A Learning-Augmented Differentiable Meshfree Framework for GPU-Accelerated Lagrangian Simulation and Geophysical Inverse Modeling
This work provides a scalable platform for fast physical simulation and data assimilation in geomechanics and geophysical hazards. Its contribution is incremental: it integrates differentiable programming with an existing meshfree method, the material point method (MPM).
The paper tackles efficient, differentiable simulation of complex geophysical systems by developing JAX-MPM, a GPU-accelerated meshfree solver. A high-resolution 3D simulation with 2.7 million particles completes 1000 time steps in about 22 seconds (single precision) on a single GPU, and the solver supports gradient-based optimization for inverse modeling tasks such as parameter estimation.
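The following minimal JAX sketch illustrates what gradient-based parameter estimation through a time-stepping solver involves. It uses a hypothetical toy problem, not JAX-MPM itself: a block sliding with Coulomb friction, whose friction coefficient `mu` is recovered from a synthetic observation of its final position. The names `simulate`, `loss`, `MU_TRUE` and all parameter values are illustrative assumptions. Only `jax.lax.scan`, `jax.value_and_grad` and `jax.jit` are real JAX APIs.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem (not JAX-MPM code): a block sliding on a rough
# surface decelerates at mu * G. We recover the friction coefficient mu from
# the block's observed final position by differentiating through the loop.
G = 9.81        # gravitational acceleration (m/s^2)
DT = 1e-3       # time-step size (s)
N_STEPS = 1000  # number of explicit time steps


def simulate(mu, v0=5.0):
    """Explicit time stepping; returns the block's final position (m)."""
    def step(state, _):
        x, v = state
        # Coulomb friction decelerates the block; clamp so it never reverses.
        v_new = jnp.maximum(v - mu * G * DT, 0.0)
        return (x + v_new * DT, v_new), None

    init = (jnp.zeros((), jnp.float32), jnp.asarray(v0, jnp.float32))
    (x_final, _), _ = jax.lax.scan(step, init, None, length=N_STEPS)
    return x_final


MU_TRUE = 0.35
x_obs = simulate(MU_TRUE)  # synthetic "observation" of the final position


def loss(mu):
    """Squared misfit between simulated and observed final position."""
    return (simulate(mu) - x_obs) ** 2


loss_and_grad = jax.jit(jax.value_and_grad(loss))
mu = jnp.asarray(0.1, jnp.float32)  # initial guess
for _ in range(100):
    _, g = loss_and_grad(mu)
    mu = mu - 0.01 * g  # plain gradient descent

print(float(mu))  # should approach MU_TRUE
```

Because the explicit loop is written with `jax.lax.scan`, reverse-mode differentiation through all 1000 steps comes from JAX directly, with no hand-derived adjoint equations.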
Differentiable programming has emerged as a powerful paradigm in scientific computing, enabling automatic differentiation through simulation pipelines and naturally supporting both forward and inverse modeling. We present JAX-MPM, a general-purpose differentiable meshfree solver based on the material point method (MPM) and implemented in JAX. The solver adopts a hybrid Eulerian-Lagrangian framework to capture large deformations, frictional contact, and inelastic material behavior, with an emphasis on geomechanics and geophysical hazard applications. Leveraging GPU acceleration and automatic differentiation, JAX-MPM enables efficient gradient-based optimization directly through its time-stepping solvers and supports joint training of physical models with deep learning to infer unknown system conditions and uncover hidden constitutive parameters. We validate JAX-MPM through a series of 2D and 3D benchmark simulations, including dam-break and granular collapse problems, demonstrating both numerical accuracy and GPU-accelerated performance. Results show that a high-resolution 3D granular cylinder collapse with 2.7 million particles completes 1000 time steps in approximately 22 seconds (single precision) and 98 seconds (double precision) on a single GPU. Beyond high-fidelity forward modeling, we demonstrate the framework's inverse modeling capabilities through tasks such as velocity field reconstruction and the estimation of spatially varying friction from sparse data. In particular, JAX-MPM accommodates data assimilation from both Lagrangian (particle-based) and Eulerian (region-based) observations, and can be seamlessly coupled with neural network representations. These results establish JAX-MPM as a unified and scalable differentiable meshfree platform that advances fast physical simulation and data assimilation for complex solid and geophysical systems.
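The hybrid Eulerian-Lagrangian cycle and the two observation types described above can be sketched in a small amount of JAX. The example is a deliberately simple 1D stand-in, not JAX-MPM's implementation. Particles carry mass, velocity and strain. They are scattered to a background grid (P2G), updated there with a small-strain linear-elastic internal force, and gathered back (G2P) with a PIC-style transfer. The script then reconstructs a hypothetical two-parameter initial velocity field `a + b (x - 0.5)` by gradient descent. It uses both sparse Lagrangian observations (the final position of every fourth particle) and an Eulerian observation (particle mass projected onto grid nodes). Grid size, material constants, step counts, learning rate and all function names are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D setup; all values are illustrative, not JAX-MPM defaults.
N_NODES, DX = 33, 1.0 / 32               # Eulerian background grid on [0, 1]
DT, N_STEPS = 1e-3, 100                  # time-step size and number of steps
E, RHO = 1.0, 1.0                        # Young's modulus and density
X0 = 0.25 + (jnp.arange(32) + 0.5) / 64  # 2 particles per cell on [0.25, 0.75]
VOL = jnp.full(32, 1.0 / 64)             # Lagrangian particle volumes
MASS = RHO * VOL                         # Lagrangian particle masses


def weights(x):
    """Linear hat-function weights, their gradients and node indices."""
    base = jnp.floor(x / DX).astype(jnp.int32)  # left node of each cell
    frac = x / DX - base                        # local coordinate in [0, 1)
    w = jnp.stack([1.0 - frac, frac])           # shape (2, P)
    dw = jnp.array([-1.0, 1.0])[:, None] / DX   # shape (2, 1)
    return w, dw, jnp.stack([base, base + 1])   # indices, shape (2, P)


def grid_mass(x):
    """Eulerian (region-based) observable: particle mass projected to nodes."""
    w, _, idx = weights(x)
    return jnp.zeros(N_NODES).at[idx].add(w * MASS)


def mpm_step(state, _):
    """One explicit PIC-style MPM step with small-strain linear elasticity."""
    x, v, eps = state
    w, dw, idx = weights(x)
    # P2G: scatter mass, momentum and internal force onto the grid.
    m = jnp.zeros(N_NODES).at[idx].add(w * MASS)
    p = jnp.zeros(N_NODES).at[idx].add(w * MASS * v)
    f = jnp.zeros(N_NODES).at[idx].add(-dw * VOL * E * eps)
    # Grid update; the double `where` keeps gradients finite at empty nodes.
    has_mass = m > 0
    v_grid = jnp.where(has_mass, (p + DT * f) / jnp.where(has_mass, m, 1.0), 0.0)
    # G2P: gather velocity and velocity gradient, then move the particles.
    v_new = jnp.sum(w * v_grid[idx], axis=0)
    grad_v = jnp.sum(dw * v_grid[idx], axis=0)
    return (x + DT * v_new, v_new, eps + DT * grad_v), None


def simulate(theta):
    """Final particle positions for initial velocity a + b (x - 0.5)."""
    v0 = theta[0] + theta[1] * (X0 - 0.5)
    init = (X0, v0, jnp.zeros_like(X0))
    (x, _, _), _ = jax.lax.scan(mpm_step, init, None, length=N_STEPS)
    return x


def loss(theta, x_obs, m_obs):
    """Lagrangian misfit on sparse particles plus Eulerian misfit on grid mass."""
    x = simulate(theta)
    lagrangian = jnp.mean((x[::4] - x_obs) ** 2)      # sparse particle tracks
    eulerian = jnp.mean((grid_mass(x) - m_obs) ** 2)  # grid-node mass field
    return lagrangian + eulerian


# Synthetic observations from a "true" velocity field, then reconstruction.
x_true = simulate(jnp.array([0.1, -0.2]))
x_obs, m_obs = x_true[::4], grid_mass(x_true)
loss_and_grad = jax.jit(jax.value_and_grad(loss))

theta = jnp.zeros(2)
loss_init, _ = loss_and_grad(theta, x_obs, m_obs)
for _ in range(50):
    _, g = loss_and_grad(theta, x_obs, m_obs)
    theta = theta - 25.0 * g  # plain gradient descent
loss_final, _ = loss_and_grad(theta, x_obs, m_obs)
print(float(loss_init), float(loss_final), theta)
```

In this sketch the uniform part `a` is recovered quickly, while the gradient part `b` converges more slowly because the observations are less sensitive to it. In JAX, double precision is generally enabled with `jax.config.update("jax_enable_x64", True)` before any arrays are created. Whether JAX-MPM selects its single- and double-precision modes this way is an assumption, not something stated above.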