LG · AI · ML · Feb 8, 2022

PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX

arXiv:2202.04110v4 · 8 citations · Has Code
Originality: Incremental advance
AI Analysis

PGMax gives machine-learning researchers and practitioners a tool for efficiently working with complex probabilistic models. The contribution is incremental, since it builds on established components such as factor graphs, loopy belief propagation, and JAX.

The authors address the problem of specifying and running inference on discrete probabilistic graphical models. Their solution, PGMax, is an open-source Python package that represents models as factor graphs and runs loopy belief propagation in JAX; they report inference speedups of up to three orders of magnitude over existing alternatives.

PGMax is an open-source Python package for (a) easily specifying discrete Probabilistic Graphical Models (PGMs) as factor graphs; and (b) automatically running efficient and scalable loopy belief propagation (LBP) in JAX. PGMax supports general factor graphs with tractable factors, and leverages modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with up to three orders-of-magnitude inference time speedups. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up new research possibilities. Our source code, examples and documentation are available at https://github.com/deepmind/PGMax.
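To make the abstract's two steps concrete, here is a minimal sketch of (a) specifying a small discrete factor graph and (b) running sum-product loopy belief propagation on it. It is not PGMax's API: the `loopy_bp` and `brute_force_marginals` functions, the `(variables, table)` factor format, and the damping scheme are hypothetical choices made for illustration. The sketch uses plain NumPy with Python loops rather than JAX, so it gets none of the accelerator speedups described above. On a tree-structured graph, belief propagation recovers the exact marginals, which the brute-force enumeration lets you check; adding a factor that closes a cycle turns the result into an approximation.

```python
import itertools

import numpy as np


def loopy_bp(cards, factors, num_iters=100, damping=0.5):
    """Sum-product loopy belief propagation on a discrete factor graph (sketch).

    cards:   cards[v] is the number of states of variable v.
    factors: list of (vars, table) pairs, where `vars` is a tuple of variable
             indices and `table` is a non-negative array of shape
             tuple(cards[v] for v in vars).
    Returns one normalized belief (approximate marginal) per variable.
    """
    # Messages are keyed by (factor index, variable index); start uniform.
    v2f, f2v = {}, {}
    for a, (vs, _) in enumerate(factors):
        for v in vs:
            v2f[(a, v)] = np.full(cards[v], 1.0 / cards[v])
            f2v[(a, v)] = np.full(cards[v], 1.0 / cards[v])

    for _ in range(num_iters):
        # Factor-to-variable: weight the table by incoming variable messages,
        # then sum out every variable except the recipient.
        new_f2v = {}
        for a, (vs, table) in enumerate(factors):
            for i, v in enumerate(vs):
                prod = np.array(table, dtype=float)
                for j, u in enumerate(vs):
                    if j != i:
                        shape = [1] * len(vs)
                        shape[j] = cards[u]
                        prod = prod * v2f[(a, u)].reshape(shape)
                other_axes = tuple(j for j in range(len(vs)) if j != i)
                msg = prod.sum(axis=other_axes) if other_axes else prod
                msg = msg / msg.sum()
                # Damping: mix new and old messages to help convergence on loops.
                new_f2v[(a, v)] = damping * f2v[(a, v)] + (1 - damping) * msg
        f2v = new_f2v

        # Variable-to-factor: product of messages from all *other* factors.
        for a, v in v2f:
            msg = np.ones(cards[v])
            for (b, u), m in f2v.items():
                if u == v and b != a:
                    msg = msg * m
            v2f[(a, v)] = msg / msg.sum()

    # Belief of a variable: product of all incoming factor messages.
    beliefs = []
    for v in range(len(cards)):
        b = np.ones(cards[v])
        for (_, u), m in f2v.items():
            if u == v:
                b = b * m
        beliefs.append(b / b.sum())
    return beliefs


def brute_force_marginals(cards, factors):
    """Exact marginals by enumerating every joint state (tiny graphs only)."""
    joint = np.zeros(cards)
    for states in itertools.product(*(range(c) for c in cards)):
        p = 1.0
        for vs, table in factors:
            p *= table[tuple(states[v] for v in vs)]
        joint[states] = p
    joint /= joint.sum()
    return [joint.sum(axis=tuple(u for u in range(len(cards)) if u != v))
            for v in range(len(cards))]


# Chain x0 - x1 - x2 (a tree): loopy BP should match the exact marginals.
agree = np.array([[2.0, 1.0], [1.0, 2.0]])  # pairwise factor favoring equal states
cards = [2, 2, 2]
chain = [((0,), np.array([3.0, 1.0])),  # unary factor biasing x0 toward state 0
         ((0, 1), agree),
         ((1, 2), agree)]
bp = loopy_bp(cards, chain)
exact = brute_force_marginals(cards, chain)
for b, e in zip(bp, exact):
    print(np.round(b, 4), np.round(e, 4))

# Closing the loop x2 - x0 makes the graph cyclic; beliefs are now approximate.
loopy = loopy_bp(cards, chain + [((0, 2), agree)])
print([np.round(b, 4) for b in loopy])
```

Damping (averaging each new message with the previous one) is a common heuristic for helping loopy BP converge on cyclic graphs; with this form it does not change the fixed points. PGMax's own interface for declaring variables and factors differs from this sketch and is described in the documentation linked above.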

Code Implementations: 2 repos