A research framework for writing differentiable PDE discretizations in JAX
This work provides a framework that lets researchers in fields such as reinforcement learning and optimal control implement differentiable PDE discretizations with little effort. The contribution is incremental, since it builds on existing ideas from differentiable simulation.
The authors address the construction of differentiable PDE simulators by proposing a design pattern for building a library of differentiable operators and discretizations in JAX. They demonstrate it on an acoustic optimization problem, using gradient descent to optimize the speed of sound of an acoustic lens.
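To make the design pattern concrete, here is a minimal sketch in JAX. It assumes a 1D periodic domain, and the classes `FourierField` and `FDField` and the function `derivative` are hypothetical illustrations, not the jaxdf API. Each field pairs a finite parameter vector (grid samples) with a discretization, and the single continuous operator d/dx gets two discrete representations, selected by the field's type. Gradients with respect to the parameters then come directly from `jax.grad`.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class FourierField:
    """Hypothetical field: grid samples read as a trigonometric interpolant."""
    params: jnp.ndarray  # samples u(x_j) at x_j = j * length / N on a periodic grid
    length: float        # period of the domain


@dataclass(frozen=True)
class FDField:
    """Hypothetical field: the same kind of samples, read for finite differences."""
    params: jnp.ndarray
    length: float


def derivative(field):
    """d/dx as a map from a field to a field of the same discretization."""
    n = field.params.shape[0]
    dx = field.length / n
    if isinstance(field, FourierField):
        # Spectral derivative: multiply each Fourier coefficient by i*k.
        k = 2 * jnp.pi * jnp.fft.fftfreq(n, d=dx)
        du = jnp.real(jnp.fft.ifft(1j * k * jnp.fft.fft(field.params)))
        return FourierField(du, field.length)
    if isinstance(field, FDField):
        # Second-order central difference with periodic wrap-around.
        du = (jnp.roll(field.params, -1) - jnp.roll(field.params, 1)) / (2 * dx)
        return FDField(du, field.length)
    raise TypeError(f"no discretization of d/dx for {type(field).__name__}")


def energy(amplitude, field_cls, n=64, length=2 * jnp.pi):
    """Approximate the integral of (du/dx)^2 for u(x) = amplitude * sin(x)."""
    x = jnp.arange(n) * length / n
    u = field_cls(amplitude * jnp.sin(x), length)
    du = derivative(u)
    return jnp.sum(du.params ** 2) * (length / n)


# The same loss, differentiated through two discretizations of one operator.
grad_fourier = jax.grad(lambda a: energy(a, FourierField))(1.0)
grad_fd = jax.grad(lambda a: energy(a, FDField))(1.0)
print(grad_fourier, grad_fd)
```

For u(x) = a sin(x) the energy is pi * a^2, so its derivative at a = 1 is 2*pi. The Fourier representation differentiates the sine exactly and recovers this value up to float32 rounding, while the central difference scales the derivative by sin(dx)/dx and yields a slightly smaller gradient. Dispatching on the field type keeps the user-facing operator unchanged when the discretization is swapped.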
Differentiable simulators are an emerging concept with applications in several fields, from reinforcement learning to optimal control. Their distinguishing feature is the ability to compute analytic gradients with respect to the input parameters. Just as neural networks are built by composing building blocks called layers, a simulation often requires computing the output of an operator that can itself be decomposed into a chain of elementary units. However, while each layer of a neural network represents a specific discrete operation, the same operator can have multiple representations, depending on the discretization employed and the research question being addressed. Here, we propose a simple design pattern for constructing a library of differentiable operators and discretizations, in which operators are represented as mappings between families of continuous functions parametrized by finite vectors. We demonstrate the approach on an acoustic optimization problem in which the Helmholtz equation is discretized using Fourier spectral methods, and we show differentiability by using gradient descent to optimize the speed of sound of an acoustic lens. The proposed framework is open source and available at https://github.com/ucl-bug/jaxdf.
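The acoustic experiment can be imitated at toy scale. The sketch below is a stand-in under explicit assumptions, not the paper's setup: it uses a 1D periodic domain, a dense Fourier spectral Laplacian, a small loss term `DAMPING` in place of absorbing boundaries, and hypothetical names (`solve_helmholtz`, `loss`, `TARGET`). It solves u'' + (omega / c(x))^2 (1 + i*DAMPING) u = -s and runs gradient descent on a perturbation theta of the sound speed c = 1 + theta to raise the wave intensity at one grid point.

```python
import jax
import jax.numpy as jnp

# Toy 1D setup; every constant here is a hypothetical choice for illustration.
N = 64                 # grid points on the periodic domain [0, 2*pi)
LENGTH = 2 * jnp.pi
OMEGA = 3.3            # angular frequency, away from the integer resonances of the box
DAMPING = 0.05         # assumed lossy term, standing in for absorbing boundaries
TARGET = 40            # grid index where we want to focus the field

x = jnp.arange(N) * LENGTH / N
k = 2 * jnp.pi * jnp.fft.fftfreq(N, d=LENGTH / N)
SOURCE = jnp.exp(-(((x - 1.0) / 0.3) ** 2))  # Gaussian source centred at x = 1

# Dense Fourier spectral Laplacian F^{-1} diag(-k^2) F, obtained by applying
# the spectral second derivative to every column of the identity matrix.
LAPLACIAN = jnp.real(
    jnp.fft.ifft(-(k ** 2)[:, None] * jnp.fft.fft(jnp.eye(N), axis=0), axis=0)
)


def solve_helmholtz(c):
    """Solve u'' + (OMEGA / c)^2 (1 + i*DAMPING) u = -SOURCE for sound speed c(x)."""
    wavenumber_sq = (OMEGA / c) ** 2 * (1 + 1j * DAMPING)
    A = LAPLACIAN + jnp.diag(wavenumber_sq)
    return jnp.linalg.solve(A, -SOURCE.astype(A.dtype))


def loss(theta):
    """Negative intensity |u|^2 at TARGET for sound speed c = 1 + theta."""
    u = solve_helmholtz(1.0 + theta)
    return -jnp.real(u[TARGET] * jnp.conj(u[TARGET]))


value_and_grad = jax.jit(jax.value_and_grad(loss))
theta = jnp.zeros(N)
step = 0.005
losses = []
for _ in range(30):
    value, g = value_and_grad(theta)
    losses.append(float(value))
    # Normalized gradient step: fixed step length in parameter space.
    theta = theta - step * g / (jnp.linalg.norm(g) + 1e-12)

print(losses[0], losses[-1])
```

Normalizing the gradient fixes the step length in parameter space, which keeps the sound speed close to its background value of 1 regardless of the loss scale. One could instead use an adaptive optimizer such as those in Optax; the point here is only that the whole spectral solve is differentiable end to end.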