Membrane Potential Batch Normalization for Spiking Neural Networks
This work addresses a specific bottleneck in training deep SNNs for energy-efficient AI applications, offering an incremental improvement over existing batch normalization techniques.
The paper tackles the disturbance of the normalized data flow in spiking neural networks (SNNs) caused by membrane potential dynamics. It proposes Membrane Potential Batch Normalization (MPBN), which normalizes the membrane potential before the firing function and uses re-parameterization to avoid any extra inference cost. Experimental results show that MPBN performs well on static and neuromorphic datasets, and the code is open-sourced.
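Why the re-parameterization costs nothing at inference can be seen from a short derivation. The notation here is our own assumption, since the summary gives no equations: $u$ is the membrane potential just before firing, $\mu$ and $\sigma^2$ are the MPBN running statistics frozen after training, $\gamma$ and $\beta$ are its learned scale and shift, $V_{th}$ is the firing threshold, and $\Theta$ is the Heaviside step with $\Theta(0) = 1$. It also assumes $\gamma > 0$ and that MPBN only affects the firing decision, not the stored potential.

$$\tilde{u} = \gamma\,\frac{u - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \qquad s = \Theta\!\left(\tilde{u} - V_{th}\right)$$

$$\tilde{u} \ge V_{th} \iff u \ge \frac{(V_{th} - \beta)\sqrt{\sigma^2 + \epsilon}}{\gamma} + \mu =: V_{th}', \qquad \text{so} \quad s = \Theta\!\left(u - V_{th}'\right)$$

In the element-wise form every symbol carries a per-neuron index, so the folded result is simply a per-neuron threshold $V_{th}'$. In the channel-wise form, neurons in the same channel share one threshold.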
As one of the energy-efficient alternatives to conventional neural networks (CNNs), spiking neural networks (SNNs) have attracted growing interest recently. To train deep models, several effective batch normalization (BN) techniques have been proposed for SNNs. All of these BNs are meant to be placed after the convolution layer, as is usual in CNNs. However, the spiking neuron is much more complex, with spatio-temporal dynamics. The data flow regulated by the BN layer is disturbed again by the membrane potential updating operation before the firing function, i.e., the nonlinear activation. Therefore, we advocate adding another BN layer before the firing function to normalize the membrane potential again, called MPBN. To eliminate the time cost induced by MPBN, we also propose a training-inference-decoupled re-parameterization technique that folds the trained MPBN into the firing threshold. With this technique, MPBN introduces no extra time burden at inference. Furthermore, MPBN can also adopt an element-wise form, whereas BNs placed after the convolution layer can only use the channel-wise form. Experimental results show that the proposed MPBN performs well on both popular non-spiking static and neuromorphic datasets. Our code is open-sourced at https://github.com/yfguo91/MPBN.
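The claim that a trained MPBN can be folded into the firing threshold can be checked numerically. The NumPy sketch below is not the authors' implementation from the linked repository. It rests on several assumptions of our own: hard-reset LIF neurons with a hypothetical leak factor `tau = 0.5`; MPBN in inference mode with frozen running statistics; MPBN influencing only the firing decision while leak and reset act on the raw potential `u`; `gamma > 0`; and element-wise (per-neuron) parameters. The helper names `mpbn`, `folded_threshold` and `simulate` are hypothetical.

```python
import numpy as np


def mpbn(u, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode BN applied to the membrane potential (frozen running statistics).
    return gamma * (u - mean) / np.sqrt(var + eps) + beta


def folded_threshold(v_th, gamma, beta, mean, var, eps=1e-5):
    # Solve mpbn(u) >= v_th for u; this closed form is exact only where gamma > 0.
    if np.any(np.asarray(gamma) <= 0):
        raise ValueError("folding as written assumes gamma > 0")
    return (v_th - beta) * np.sqrt(var + eps) / gamma + mean


def simulate(inputs, fire, tau=0.5):
    # Hypothetical hard-reset LIF neurons; inputs has shape (T, N).
    u = np.zeros(inputs.shape[1])
    s = np.zeros(inputs.shape[1])
    spikes = []
    for x in inputs:
        u = tau * u * (1.0 - s) + x  # leak, reset neurons that spiked last step, integrate
        s = fire(u).astype(float)    # firing decision (fires at equality)
        spikes.append(s)
    return np.stack(spikes)


rng = np.random.default_rng(0)
T, N = 8, 6
v_th = 1.0
# Element-wise MPBN: one (gamma, beta, mean, var) per neuron.
gamma = rng.uniform(0.5, 2.0, size=N)
beta = rng.normal(size=N)
mean = rng.normal(size=N)
var = rng.uniform(0.5, 2.0, size=N)
inputs = rng.normal(size=(T, N))

# Explicit path: MPBN evaluated before every threshold comparison.
spikes_mpbn = simulate(inputs, lambda u: mpbn(u, gamma, beta, mean, var) >= v_th)

# Folded path: MPBN absorbed into a per-neuron threshold, no BN op per time step.
theta = folded_threshold(v_th, gamma, beta, mean, var)
spikes_folded = simulate(inputs, lambda u: u >= theta)

print(np.array_equal(spikes_mpbn, spikes_folded))
```

Both paths produce the same spike trains, up to floating-point rounding at the exact threshold boundary, but the folded path replaces the per-step normalization with a single comparison. For a channel-wise variant, the same functions would work with parameters shaped `(C, 1, 1)` broadcast over `(C, H, W)` potentials, which yields one threshold per channel instead of one per neuron.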