Decomposing reverse-mode automatic differentiation
This work is an incremental improvement aimed at developers and researchers in machine learning and scientific computing: it makes AD implementations more modular and easier to maintain.
The paper addresses the complexity of implementing reverse-mode automatic differentiation by decomposing it into forward-mode linearization followed by transposition, so that forward and reverse mode can share one set of differentiation rules. Once forward-mode rules exist for every primitive, only the linear primitives need an additional transposition rule. This is the approach used in JAX and Dex.
We decompose reverse-mode automatic differentiation into (forward-mode) linearization followed by transposition. Doing so isolates the essential difference between forward- and reverse-mode AD, and simplifies their joint implementation. In particular, once forward-mode AD rules are defined for every primitive operation in a source language, only linear primitives require an additional transposition rule in order to arrive at a complete reverse-mode AD implementation. This is how reverse-mode AD is written in JAX and Dex.
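The decomposition can be illustrated with a minimal sketch in plain Python. It is a toy for scalar functions, not the implementation used in JAX or Dex, and every name in it (LinearProgram, Dual, linearize, push_forward, pull_back) is hypothetical. The nonlinear primitives sin and mul get forward-mode rules only; those rules emit just two linear primitives, scale and add, and transposition rules are written for those two alone.

```python
import math

class LinearProgram:
    """Tape of linear equations over tangent variables, named by integer ids."""
    def __init__(self):
        self.eqns = []   # entries: (op, out_id, in_ids, constant)
        self.n_vars = 0

    def new_var(self):
        self.n_vars += 1
        return self.n_vars - 1

    def scale(self, c, t):   # linear primitive: out = c * t
        out = self.new_var()
        self.eqns.append(("scale", out, (t,), c))
        return out

    def add(self, t1, t2):   # linear primitive: out = t1 + t2
        out = self.new_var()
        self.eqns.append(("add", out, (t1, t2), None))
        return out

class Dual:
    """A primal value paired with the id of its tangent variable."""
    def __init__(self, primal, tangent, prog):
        self.primal, self.tangent, self.prog = primal, tangent, prog

# Forward-mode rules for the nonlinear primitives. Each emits only linear ops.
def sin(x):
    p = x.prog
    return Dual(math.sin(x.primal), p.scale(math.cos(x.primal), x.tangent), p)

def mul(x, y):
    p = x.prog
    t = p.add(p.scale(y.primal, x.tangent), p.scale(x.primal, y.tangent))
    return Dual(x.primal * y.primal, t, p)

def linearize(f, x):
    """Evaluate f at x and record its tangent map as a linear program."""
    prog = LinearProgram()
    t_in = prog.new_var()
    out = f(Dual(x, t_in, prog))
    return out.primal, prog, t_in, out.tangent

def push_forward(prog, t_in, t_out, v):
    """Forward mode: run the linear program on an input tangent v."""
    env = {t_in: v}
    for op, out, ins, c in prog.eqns:
        if op == "scale":
            env[out] = c * env[ins[0]]
        else:  # add
            env[out] = env[ins[0]] + env[ins[1]]
    return env[t_out]

def pull_back(prog, t_in, t_out, ct):
    """Reverse mode: run the transposed equations in reverse order."""
    cts = [0.0] * prog.n_vars
    cts[t_out] = ct
    for op, out, ins, c in reversed(prog.eqns):
        g = cts[out]
        if op == "scale":      # transpose of scaling by c is scaling by c
            cts[ins[0]] += c * g
        else:                  # transpose of add is fan-out to both inputs
            cts[ins[0]] += g
            cts[ins[1]] += g
    return cts[t_in]

def f(x):
    return mul(sin(x), x)

y, prog, t_in, t_out = linearize(f, 2.0)
forward = push_forward(prog, t_in, t_out, 1.0)   # derivative via forward mode
reverse = pull_back(prog, t_in, t_out, 1.0)      # derivative via transposition
```

Both forward and reverse evaluate the derivative cos(x) * x + sin(x) at x = 2.0. Adding another nonlinear primitive to this sketch would require only a new forward-mode rule, since pull_back never inspects sin or mul directly.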