Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro
This work provides a flexible and accelerated probabilistic programming library for researchers and practitioners, though the contribution is incremental because it builds on the existing Pyro and JAX frameworks.
The authors tackle the challenge of combining Pyro's probabilistic programming interface with JAX's hardware acceleration by developing NumPyro. NumPyro uses composable effect handlers to support an iterative No-U-Turn Sampler that is end-to-end JIT compiled, making it significantly faster than existing alternatives in both small and large dataset regimes.
NumPyro is a lightweight library that provides an alternative NumPy backend to the Pyro probabilistic programming language, with the same modeling interface, language primitives, and effect handling abstractions. Effect handlers allow Pyro's modeling API to be carried over to NumPyro even though NumPyro is built on a fundamentally different, JAX-based functional backend. In this work, we demonstrate the power of composing Pyro's effect handlers with the program transformations that enable hardware acceleration, automatic differentiation, and vectorization in JAX. In particular, NumPyro provides an iterative formulation of the No-U-Turn Sampler (NUTS) that can be end-to-end JIT compiled, yielding an implementation that is much faster than existing alternatives in both the small and large dataset regimes.