PokeBNN: A Binary Pursuit of Lightweight Accuracy
This work addresses the need for lightweight, accurate neural networks at inference time. It substantially improves binary neural networks, although the contribution is incremental in that it builds on existing BNN methods.
The paper tackles the low accuracy of binary neural networks (BNNs) on image classification by proposing PokeConv, a binary convolution block that uses techniques such as multiple residual paths and a tuned activation function, and applying it to ResNet-50. The resulting network family, PokeBNN, achieves state-of-the-art performance relative to ReActNet-Adam (70.5% top-1 at 7.9 ACE): a small variant matches that accuracy at 2.6 ACE, cutting cost by more than 3x, and a larger variant reaches 75.6% top-1 at 7.8 ACE, a gain of more than 5 percentage points without increasing cost.
Optimizing for top-1 ImageNet accuracy promotes enormous networks that may be impractical in inference settings. Binary neural networks (BNNs) have the potential to significantly lower compute intensity, but existing models suffer from low quality. To overcome this deficiency, we propose PokeConv, a binary convolution block that improves the quality of BNNs through techniques such as adding multiple residual paths and tuning the activation function. We apply it to ResNet-50 and optimize ResNet's initial convolutional layer, which is hard to binarize. We name the resulting network family PokeBNN. These techniques are chosen to yield favorable improvements in both top-1 accuracy and the network's cost. To enable joint optimization of cost and accuracy, we define arithmetic computation effort (ACE), a hardware- and energy-inspired cost metric for quantized and binarized networks. We also identify a need to optimize an under-explored hyper-parameter controlling the binarization gradient approximation. We establish a new, strong state of the art (SOTA) on top-1 accuracy together with the commonly used CPU64 cost, ACE cost, and network size metrics. ReActNet-Adam, the previous SOTA in BNNs, achieved 70.5% top-1 accuracy with 7.9 ACE. A small variant of PokeBNN achieves 70.5% top-1 with 2.6 ACE, a more than 3x reduction in cost; a larger PokeBNN achieves 75.6% top-1 with 7.8 ACE, an improvement of more than 5 percentage points in accuracy without increasing the cost. The PokeBNN implementation in JAX/Flax and reproduction instructions are available in the AQT repository: https://github.com/google/aqt
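The headline comparison can be checked with a few lines of arithmetic. The sketch below uses only the accuracy and ACE figures quoted above. The `ace` helper is hypothetical and rests on an assumption that the abstract does not state: that ACE weights each multiply-accumulate by the product of its two operand bit widths. It is included only to illustrate why binarized operations would be cheap under such a metric, not as the authors' implementation.

```python
# Figures quoted in the abstract: top-1 accuracy in percent, cost in ACE.
REACTNET_ADAM = {"top1": 70.5, "ace": 7.9}
POKEBNN_SMALL = {"top1": 70.5, "ace": 2.6}
POKEBNN_LARGE = {"top1": 75.6, "ace": 7.8}


def ace(mac_counts):
    """Hypothetical helper, not taken from the AQT repository.

    ASSUMPTION: each multiply-accumulate between an a-bit and a b-bit
    operand costs a * b. The abstract does not give the formula.
    mac_counts maps (bits_a, bits_b) to the number of such operations.
    """
    return sum(n * a * b for (a, b), n in mac_counts.items())


# Small PokeBNN versus ReActNet-Adam at equal accuracy.
cost_reduction = REACTNET_ADAM["ace"] / POKEBNN_SMALL["ace"]

# Large PokeBNN versus ReActNet-Adam at roughly equal cost.
accuracy_gain = POKEBNN_LARGE["top1"] - REACTNET_ADAM["top1"]

print(f"cost reduction: {cost_reduction:.2f}x")
print(f"accuracy gain: {accuracy_gain:.1f} percentage points")

# Under the assumed formula, one 8-bit by 8-bit operation costs as much
# as 64 binary (1-bit by 1-bit) operations.
toy_layer_cost = ace({(1, 1): 1000, (8, 8): 10})
```

Under the assumed formula, a layer that stays at 8 bits can dominate the total cost even when it performs far fewer operations than the binarized layers, which is consistent with the attention the abstract gives to ResNet's initial convolutional layer.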