Spyx: A Library for Just-In-Time Compiled Optimization of Spiking Neural Networks
This work addresses the high computational cost of SNN training, a problem facing researchers and practitioners in neuromorphic computing, and represents an incremental improvement in optimization techniques.
To train Spiking Neural Networks (SNNs) efficiently, the authors developed Spyx, a JAX-based library that combines JIT compilation with pre-staging of data in accelerator vRAM. The authors report that this achieves optimal hardware utilization on NVIDIA GPUs or Google TPUs and surpasses the performance of many existing SNN training frameworks.
As the role of artificial intelligence becomes increasingly pivotal in modern society, the efficient training and deployment of deep neural networks have emerged as critical areas of focus. Recent advancements in attention-based large neural architectures have spurred the development of AI accelerators, facilitating the training of extensive, multi-billion parameter models. Despite their effectiveness, these powerful networks often incur high execution costs in production environments. Neuromorphic computing, inspired by biological neural processes, offers a promising alternative. By utilizing temporally sparse computations, Spiking Neural Networks (SNNs) promise to enhance energy efficiency through a reduced, low-power hardware footprint. However, training SNNs can be challenging because their recurrent nature cannot as easily leverage the massive parallelism of modern AI accelerators. To facilitate the investigation of SNN architectures and dynamics, researchers have sought to bridge Python-based deep learning frameworks such as PyTorch or TensorFlow with custom-implemented compute kernels. This paper introduces Spyx, a new, lightweight SNN simulation and optimization library implemented in JAX. By pre-staging data in the expansive vRAM of contemporary accelerators and employing extensive JIT compilation, Spyx allows SNN optimization to be executed as a unified, low-level program on NVIDIA GPUs or Google TPUs. This approach achieves optimal hardware utilization, surpassing the performance of many existing SNN training frameworks while maintaining considerable flexibility.
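The two ideas named in the abstract, pre-staging data on the accelerator and compiling the whole optimization loop as one program, can be illustrated with a minimal sketch in plain JAX. This is not the Spyx API: the function names (spike, lif_forward, train_epoch), the leaky integrate-and-fire update, the fast-sigmoid style surrogate gradient, the rate-based loss, and all constants are assumptions chosen for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical constants for illustration; not taken from Spyx.
BETA = 0.9        # membrane leak factor
THRESHOLD = 1.0   # firing threshold


@jax.custom_vjp
def spike(v):
    # Forward pass: Heaviside step (1 where v > 0, else 0).
    return (v > 0.0).astype(v.dtype)


def _spike_fwd(v):
    return spike(v), v


def _spike_bwd(v, g):
    # Backward pass: smooth surrogate derivative (an assumed choice),
    # since the true derivative of the step is zero almost everywhere.
    return (g / (1.0 + jnp.abs(v)) ** 2,)


spike.defvjp(_spike_fwd, _spike_bwd)


def lif_forward(w, x_seq):
    """One leaky integrate-and-fire layer. x_seq: (T, B, n_in), w: (n_in, n_out)."""
    def step(v, x_t):
        v = BETA * v + x_t @ w        # leak, then integrate the input current
        s = spike(v - THRESHOLD)      # emit a spike where the threshold is crossed
        v = v - s * THRESHOLD         # soft reset of neurons that fired
        return v, s

    v0 = jnp.zeros((x_seq.shape[1], w.shape[1]))
    _, spikes = jax.lax.scan(step, v0, x_seq)   # recurrence over time steps
    return spikes                               # (T, B, n_out)


def loss_fn(w, x_seq, y):
    rates = lif_forward(w, x_seq).mean(axis=0)  # firing rate per output neuron
    targets = jax.nn.one_hot(y, w.shape[1])
    return jnp.mean((rates - targets) ** 2)


@jax.jit
def train_epoch(w, xs, ys, lr):
    """One epoch as a single compiled program.

    xs: (N, T, B, n_in) and ys: (N, B) are assumed to be on the device already,
    so the scan over the N batches involves no host-to-device transfers.
    """
    def batch_step(w, batch):
        x_seq, y = batch
        loss, grad = jax.value_and_grad(loss_fn)(w, x_seq, y)
        return w - lr * grad, loss    # plain SGD update

    return jax.lax.scan(batch_step, w, (xs, ys))


# Usage with synthetic data: 8 batches, 20 time steps, batch size 4, 16 inputs, 5 classes.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
xs = jax.device_put(jax.random.bernoulli(k1, 0.3, (8, 20, 4, 16)).astype(jnp.float32))
ys = jax.device_put(jax.random.randint(k2, (8, 4), 0, 5))
w0 = 0.5 * jax.random.normal(k3, (16, 5))

w_trained, losses = train_epoch(w0, xs, ys, 0.1)
```

In this sketch the inner scan handles the recurrence over time and the outer scan iterates over batches, so a single jit-compiled call covers the full epoch. The trade-off is that the whole dataset must fit in device memory, which is the condition the abstract's pre-staging approach relies on.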