laplax -- Laplace Approximations with JAX
This work provides a practical tool for researchers working on Bayesian neural networks and uncertainty quantification. Its contribution is incremental: it focuses on implementation rather than on new algorithms.
The authors address the need for scalable Bayesian uncertainty quantification in deep neural networks by introducing laplax, an open-source Python package for efficient Laplace approximations in JAX.
The Laplace approximation provides a scalable and efficient means of quantifying weight-space uncertainty in deep neural networks, enabling the application of Bayesian tools such as predictive uncertainty and model selection via Occam's razor. In this work, we introduce laplax, a new open-source Python package for performing Laplace approximations with JAX. Designed with a modular and purely functional architecture and minimal external dependencies, laplax offers a flexible and researcher-friendly framework for rapid prototyping and experimentation. Its goal is to facilitate research on Bayesian neural networks, uncertainty quantification for deep learning, and the development of improved Laplace approximation techniques.
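To make the idea concrete, the following is a minimal sketch of a Laplace approximation written in plain JAX. It does not use the laplax API; the toy data, the linear model, the noise level, the prior precision, and all function names are assumptions chosen for illustration. The sketch finds a MAP estimate, takes the Hessian of the negative log posterior at that point as the posterior precision, and samples weights from the resulting Gaussian to obtain predictive uncertainty.

```python
import jax
import jax.numpy as jnp

# Toy 1-D regression data (hypothetical, for illustration only).
x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 0.5 + 0.1 * jax.random.normal(jax.random.PRNGKey(0), x.shape)


def model(params, inputs):
    # Linear model with params = [slope, intercept].
    return params[0] * inputs + params[1]


def neg_log_posterior(params):
    # Gaussian likelihood (assumed noise std 0.1) plus an isotropic
    # Gaussian prior on the weights (assumed precision 1.0).
    residual = y - model(params, x)
    neg_log_lik = 0.5 * jnp.sum(residual**2) / 0.1**2
    neg_log_prior = 0.5 * jnp.sum(params**2)
    return neg_log_lik + neg_log_prior


# Step 1: MAP estimate. The objective is quadratic in the weights here,
# so a single Newton step from zero reaches the minimum; a neural network
# would need an iterative optimizer instead.
params0 = jnp.zeros(2)
grad0 = jax.grad(neg_log_posterior)(params0)
hess0 = jax.hessian(neg_log_posterior)(params0)
params_map = params0 - jnp.linalg.solve(hess0, grad0)

# Step 2: Laplace approximation. The posterior is approximated by
# N(params_map, H^{-1}), where H is the Hessian at the MAP estimate.
precision = jax.hessian(neg_log_posterior)(params_map)
covariance = jnp.linalg.inv(precision)

# Step 3: predictive uncertainty by sampling weights and pushing them
# through the model.
samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(1), params_map, covariance, shape=(1000,)
)
x_test = jnp.array([0.0, 3.0])  # one point inside the data range, one outside
preds = jax.vmap(lambda p: model(p, x_test))(samples)
pred_mean = preds.mean(axis=0)
pred_std = preds.std(axis=0)
```

In this sketch the predictive standard deviation is larger at the extrapolation point `x = 3.0` than inside the training range. For deep networks the full Hessian is usually too large to form and invert, so practical implementations typically rely on curvature approximations and structured or matrix-free representations of the covariance.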