PL · CL · LG · Nov 30, 2023

Automatic Functional Differentiation in JAX

arXiv:2311.18727v2 · 7 citations · h-index: 2 · Has Code
Originality: Synthesis-oriented
AI Analysis

This work provides a tool for researchers and practitioners in computational fields such as physics and machine learning who need functional derivatives, though it is incremental, since it builds on JAX's existing differentiation framework.

The authors extend JAX to automatically differentiate higher-order functions (functionals and operators). They represent functions as generalized arrays and implement primitive operators together with their linearization and transposition rules. As a result, functional gradients can be computed with the same syntax used for ordinary functions, and the resulting gradients are themselves callable Python functions.
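As a concrete picture of what a functional gradient is, the sketch below approximates one numerically in plain JAX. It assumes a uniform grid on [0, 1) and the example functional F[f] = ∫ f(x)² dx, whose functional derivative is δF/δf(x) = 2f(x); both choices are illustrative and not taken from the paper. Unlike autofd, which returns the gradient as a callable function, this sketch only recovers the gradient's values on the grid.

```python
import jax
import jax.numpy as jnp

# Assumption for illustration: functions are sampled on a uniform grid on [0, 1).
# autofd does not work this way; it treats functions as first-class values.
n = 1000
dx = 1.0 / n
xs = jnp.arange(n) * dx

def F(f_vals):
    # Riemann-sum approximation of the functional F[f] = ∫ f(x)^2 dx
    return jnp.sum(f_vals ** 2) * dx

f_vals = jnp.sin(xs)

# jax.grad gives dF/df_i = 2 f_i dx; dividing by dx approximates δF/δf at x_i
func_grad = jax.grad(F)(f_vals) / dx

# The analytic functional derivative is 2 f(x)
print(jnp.allclose(func_grad, 2 * f_vals, atol=1e-5))  # True
```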

We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions. We present a set of primitive operators that serve as foundational building blocks for constructing several key types of functionals. For every introduced primitive operator, we derive and implement both linearization and transposition rules, aligning with JAX's internal protocols for forward and reverse mode automatic differentiation. This enhancement allows for functional differentiation in the same syntax traditionally used for functions. The resulting functional gradients are themselves functions ready to be invoked in Python. We showcase this tool's efficacy and simplicity through applications where functional derivatives are indispensable. The source code of this work is released at https://github.com/sail-sg/autofd .
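The linearization and transposition rules mentioned in the abstract correspond to JAX's forward-mode (`jax.linearize`) and reverse-mode (`jax.linear_transpose`) machinery. The hedged sketch below applies these stock transformations to a discretized functional ∫ f(x)² dx, built from a pointwise square followed by an integral. The grid and the helper names `square`, `integrate`, and `functional` are hypothetical, and the code does not use autofd's own primitives, which define such rules for operators acting on functions directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical discretization: 8 samples of f on a uniform grid on [0, 1)
n = 8
dx = 1.0 / n
xs = jnp.arange(n) * dx
f = jnp.cos(xs)

def integrate(f_vals):
    # Linear functional: Riemann sum approximating ∫ f(x) dx
    return jnp.sum(f_vals) * dx

def square(f_vals):
    # Nonlinear pointwise operator f ↦ f^2
    return f_vals ** 2

def functional(f_vals):
    # Composite functional F[f] = ∫ f(x)^2 dx
    return integrate(square(f_vals))

# Forward mode: linearize F at f; the linear map is t ↦ ∫ 2 f(x) t(x) dx
y, f_jvp = jax.linearize(functional, f)
tangent = jnp.ones_like(f)
jvp_out = f_jvp(tangent)
print(jnp.allclose(jvp_out, jnp.sum(2 * f) * dx))

# Reverse mode: transpose the linear map to pull a scalar cotangent back to f
f_vjp = jax.linear_transpose(f_jvp, f)
(ct,) = f_vjp(jnp.ones_like(y))

# The transposed map reproduces the ordinary gradient, 2 f dx on this grid
print(jnp.allclose(ct, jax.grad(functional)(f)))
```

Composing a linearization with its transpose is how JAX obtains reverse mode from forward mode in general, which is why supplying both rules for each primitive is enough to make `jax.grad`-style differentiation available.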

Code Implementations: 1 repo
