NE · AI · LG · Dec 4, 2024

JPC: Flexible Inference for Predictive Coding Networks in JAX

arXiv:2412.03676v1 · 3 citations · h-index: 4 · Has Code
Originality: Synthesis-oriented
AI Analysis

JPC is an incremental improvement for machine learning researchers: a more efficient tool for studying and applying Predictive Coding Networks (PCNs).

The authors address slow training of PCNs by introducing JPC, a JAX library that uses ordinary differential equation solvers for inference. A second-order solver achieves significantly faster runtimes than standard Euler integration, with comparable performance on a range of tasks and network depths.

We introduce JPC, a JAX library for training neural networks with Predictive Coding. JPC provides a simple, fast and flexible interface to train a variety of PC networks (PCNs) including discriminative, generative and hybrid models. Unlike existing libraries, JPC leverages ordinary differential equation solvers to integrate the gradient flow inference dynamics of PCNs. We find that a second-order solver achieves significantly faster runtimes compared to standard Euler integration, with comparable performance on a range of tasks and network depths. JPC also provides some theoretical tools that can be used to study PCNs. We hope that JPC will facilitate future research of PC. The code is available at https://github.com/thebuckleylab/jpc.
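To make "integrating the gradient flow inference dynamics" concrete, here is a minimal sketch in plain JAX. It does not use the JPC API. The function names, network sizes, step size, and the choice of Heun's method as the second-order solver are all assumptions made for illustration; the abstract does not say which second-order solver JPC uses. The sketch defines a PC energy as the sum of squared layer-wise prediction errors, then relaxes the hidden activities along the negative energy gradient with either an Euler step or a Heun step.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def pc_energy(activities, weights, x, y):
    """Sum of squared prediction errors, with input x and target y clamped."""
    layers = [x] + list(activities) + [y]
    energy = 0.0
    for W, z_prev, z_next in zip(weights, layers[:-1], layers[1:]):
        pred = W @ jnp.tanh(z_prev)  # each layer predicts the next one
        energy += 0.5 * jnp.sum((z_next - pred) ** 2)
    return energy


def vector_field(activities, weights, x, y):
    """Gradient flow on the hidden activities: dz/dt = -dF/dz."""
    grads = jax.grad(pc_energy)(activities, weights, x, y)
    return tree_map(lambda g: -g, grads)


def euler_step(activities, dt, weights, x, y):
    """First-order explicit Euler update."""
    k1 = vector_field(activities, weights, x, y)
    return tree_map(lambda z, d: z + dt * d, activities, k1)


def heun_step(activities, dt, weights, x, y):
    """Second-order Heun update: average the slopes at both ends of the step."""
    k1 = vector_field(activities, weights, x, y)
    z_pred = tree_map(lambda z, d: z + dt * d, activities, k1)
    k2 = vector_field(z_pred, weights, x, y)
    return tree_map(lambda z, a, b: z + 0.5 * dt * (a + b), activities, k1, k2)


def run_inference(step_fn, activities, n_steps, dt, weights, x, y):
    """Apply a fixed number of solver steps to the activities."""
    for _ in range(n_steps):
        activities = step_fn(activities, dt, weights, x, y)
    return activities


# Hypothetical toy problem: a 4-8-8-2 network with random weights and data.
sizes = [4, 8, 8, 2]
keys = jax.random.split(jax.random.PRNGKey(0), 5)
weights = [
    jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in)
    for k, n_in, n_out in zip(keys[:3], sizes[:-1], sizes[1:])
]
x = jax.random.normal(keys[3], (sizes[0],))
y = jax.random.normal(keys[4], (sizes[-1],))

# Initialise the hidden activities with a feedforward pass.
activities, z = [], x
for W in weights[:-1]:
    z = W @ jnp.tanh(z)
    activities.append(z)

e_init = pc_energy(activities, weights, x, y)
e_euler = pc_energy(
    run_inference(euler_step, activities, 100, 0.05, weights, x, y), weights, x, y
)
e_heun = pc_energy(
    run_inference(heun_step, activities, 100, 0.05, weights, x, y), weights, x, y
)
print("initial:", e_init, "euler:", e_euler, "heun:", e_heun)
```

Both solvers should lower the energy from its feedforward starting value. A second-order method costs two gradient evaluations per step but can tolerate larger steps, which is one plausible route to the runtime gains the abstract reports; this toy example does not measure that.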

Code Implementations: 1 repo
