NE LG Sep 4, 2024

SNNAX -- Spiking Neural Networks in JAX

arXiv:2409.02842v2, 5 citations, h-index: 9
Originality: Incremental advance
AI Analysis

SNNAX gives researchers and engineers a more efficient tool for prototyping biologically inspired models and neuromorphic hardware. The advance is incremental, since it builds on existing frameworks such as JAX.

The authors address the need for fast, flexible simulation tools for Spiking Neural Networks (SNNs) with SNNAX, a JAX-based framework that aims for PyTorch-like intuitiveness and JAX-like execution speed. They report performance metrics comparing SNNAX against other frameworks commonly used for SNNs.

Spiking Neural Network (SNN) simulators are essential tools for prototyping biologically inspired models and neuromorphic hardware architectures and for predicting their performance. For such a tool, ease of use and flexibility are critical, but so is simulation speed, especially given the complexity inherent in simulating SNNs. Here, we present SNNAX, a JAX-based framework for simulating and training such models with PyTorch-like intuitiveness and JAX-like execution speed. SNNAX models are easily extended and customized to fit the desired model specifications and target neuromorphic hardware. Additionally, SNNAX offers key features for optimizing the training and deployment of SNNs, such as flexible automatic differentiation and just-in-time compilation. We evaluate SNNAX and compare it to other machine learning (ML) frameworks commonly used for programming SNNs. We provide key performance metrics, best practices, and documented examples for simulating SNNs in SNNAX, and we implement several benchmarks used in the literature.
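To make the two features named in the abstract concrete, here is a minimal sketch in plain JAX of how just-in-time compilation and automatic differentiation can be applied to a spiking layer. It does not use or reproduce the SNNAX API. The leaky integrate-and-fire update, the fast-sigmoid surrogate derivative, the rate-matching loss, and every name and constant (`spike`, `lif_step`, `run`, `loss`, `decay`, `threshold`) are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def spike(v):
    # Heaviside step: 1.0 where the argument is positive, else 0.0.
    return (v > 0).astype(v.dtype)


@spike.defjvp
def spike_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    # Assumed surrogate derivative (fast-sigmoid shape) standing in for
    # the derivative of the step function, which is zero almost everywhere.
    surrogate = 1.0 / (1.0 + jnp.abs(v)) ** 2
    return spike(v), surrogate * dv


def lif_step(w, v, x, decay=0.9, threshold=1.0):
    # Leaky integration of the weighted input (hypothetical constants).
    v = decay * v + x @ w
    s = spike(v - threshold)
    # Reset the membrane potential to zero wherever a spike occurred.
    v = v * (1.0 - s)
    return v, s


@jax.jit
def run(w, inputs):
    # inputs has shape (T, n_in); lax.scan unrolls the time loop inside
    # the compiled computation instead of looping in Python.
    v0 = jnp.zeros(w.shape[1], dtype=inputs.dtype)

    def step(v, x):
        return lif_step(w, v, x)

    _, spikes = jax.lax.scan(step, v0, inputs)
    return spikes  # shape (T, n_out)


def loss(w, inputs, target_rate):
    # Illustrative objective: match the mean firing rate to a target.
    rates = run(w, inputs).mean(axis=0)
    return jnp.mean((rates - target_rate) ** 2)


key_in, key_w = jax.random.split(jax.random.PRNGKey(0))
inputs = (jax.random.uniform(key_in, (50, 8)) < 0.3).astype(jnp.float32)
w = 0.5 * jax.random.normal(key_w, (8, 4))

spikes = run(w, inputs)
grads = jax.grad(loss)(w, inputs, 0.2)
```

The time loop is written with `jax.lax.scan` so that the whole simulation compiles to a single program, and the custom JVP rule lets `jax.grad` propagate a usable gradient through the non-differentiable spike function.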

