JAX-MPM: A Learning-Augmented Differentiable Meshfree Framework for GPU-Accelerated Lagrangian Simulation and Geophysical Inverse Modeling
Journal: arXiv
Published Date: Jul 6, 2025
Abstract
Differentiable programming that enables automatic differentiation through
simulation pipelines has emerged as a powerful paradigm in scientific
computing, supporting both forward and inverse modeling and facilitating
integration with deep learning frameworks. We present JAX-MPM, a
general-purpose differentiable meshfree solver within a hybrid
Lagrangian-Eulerian framework, tailored for simulating complex continuum
mechanics problems involving large deformations, frictional contact, and inelastic
material behavior, with emphasis on geomechanics and geophysical hazard
applications. Built on the material point method (MPM) and implemented using
the JAX computing framework, JAX-MPM is fully differentiable and
GPU-accelerated, enabling efficient gradient-based optimization directly
through time-stepping solvers. It supports joint training of physical models
and neural networks, allowing the learning of embedded closures and neural
constitutive models. We validate JAX-MPM on several 2D and 3D benchmarks,
including dam-breaks and granular collapses, demonstrating its accuracy and
performance. A high-resolution 3D granular cylinder collapse with 2.7 million
particles completes 1000 steps in ~22 seconds (single precision) and ~98
seconds (double precision) on a single GPU. Beyond forward modeling, we
demonstrate inverse modeling capabilities such as velocity field reconstruction
and spatially varying friction estimation. These results establish JAX-MPM as a
unified, scalable platform for differentiable meshfree simulation and
data-driven geomechanical inference.
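The abstract refers to gradient-based optimization "directly through time-stepping solvers" and to spatially varying friction estimation. The sketch below illustrates that idea at toy scale. It is not MPM and not the JAX-MPM API. Particles slide down a slope under per-particle Coulomb friction, the explicit time loop runs inside `jax.lax.scan`, and `jax.grad` recovers a hypothetical friction field from synthetic final positions. The slope angle, time step, step count, learning rate, and the functions `simulate`, `loss`, and `estimate_friction` are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

G = 9.81               # gravity [m/s^2]
THETA = jnp.pi / 6     # slope angle (30 degrees), illustrative assumption
DT = 0.01              # time step [s], illustrative assumption
N_STEPS = 100          # number of explicit time steps


def simulate(mu, x0, v0):
    """Slide particles down a slope with per-particle friction coefficient mu.

    Semi-implicit Euler runs inside lax.scan, so jax.grad differentiates the
    final positions with respect to mu through every time step. Toy model:
    no stick condition, no grid, no MPM transfers.
    """
    accel = G * (jnp.sin(THETA) - mu * jnp.cos(THETA))

    def step(state, _):
        x, v = state
        v = v + DT * accel   # update velocity first (semi-implicit Euler)
        x = x + DT * v       # then advance position with the new velocity
        return (x, v), None

    (x_final, _), _ = jax.lax.scan(step, (x0, v0), None, length=N_STEPS)
    return x_final


def loss(mu, x0, v0, x_obs):
    """Mean squared misfit between simulated and observed final positions."""
    return jnp.mean((simulate(mu, x0, v0) - x_obs) ** 2)


# Gradient of the misfit w.r.t. the friction field, backpropagated through the solver.
grad_loss = jax.jit(jax.grad(loss))


def estimate_friction(x0, v0, x_obs, mu_init, lr=0.1, iters=300):
    """Plain gradient descent on the per-particle friction coefficients."""
    mu = mu_init
    for _ in range(iters):
        mu = mu - lr * grad_loss(mu, x0, v0, x_obs)
    return mu


# Usage: synthetic "observations" from a known, spatially varying friction field.
n = 8
x0 = jnp.zeros(n)
v0 = jnp.ones(n)
mu_true = jnp.linspace(0.1, 0.4, n)
x_obs = simulate(mu_true, x0, v0)
mu_est = estimate_friction(x0, v0, x_obs, mu_init=jnp.full(n, 0.25))
print(mu_est)
```

In this toy model the final positions depend linearly on `mu`, so the misfit is quadratic and plain gradient descent converges. A real inverse MPM problem is nonlinear and typically needs a proper optimizer, regularization, and care with non-smooth contact terms.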