CV · AI · AR · LG · Feb 4, 2023

Oscillation-free Quantization for Low-bit Vision Transformers

arXiv:2302.02210v3 · 65 citations · h-index: 36 · Has Code
Originality: Incremental advance
AI Analysis

This work addresses training instability and sub-optimal accuracy in quantized vision transformers, which matters for efficient deployment on resource-constrained devices. The advance is incremental because it builds on existing quantization methods.

The paper tackles weight oscillation in quantization-aware training for low-bit Vision Transformers. It identifies the learnable scaling factor and query-key interdependence as causes and proposes three techniques (StatsQ, CGA, and QKR). The resulting 2-bit DeiT-T and DeiT-S models outperform the previous state of the art on ImageNet by 9.8% and 7.7%, respectively.

Weight oscillation is an undesirable side effect of quantization-aware training, in which quantized weights frequently jump between two quantized levels, resulting in training instability and a sub-optimal final model. We discover that the learnable scaling factor, a widely-used $\textit{de facto}$ setting in quantization, aggravates weight oscillation. In this study, we investigate the connection between the learnable scaling factor and quantized weight oscillation and use ViT as a case driver to illustrate the findings and remedies. In addition, we find that the interdependence between quantized weights in $\textit{query}$ and $\textit{key}$ of a self-attention layer makes ViT vulnerable to oscillation. We therefore propose three techniques accordingly: statistical weight quantization ($\rm StatsQ$) to improve quantization robustness compared to the prevalent learnable-scale-based method; confidence-guided annealing ($\rm CGA$) that freezes the weights with $\textit{high confidence}$ and calms the oscillating weights; and $\textit{query}$-$\textit{key}$ reparameterization ($\rm QKR$) to resolve the query-key intertwined oscillation and mitigate the resulting gradient misestimation. Extensive experiments demonstrate that these proposed techniques successfully abate weight oscillation and consistently achieve substantial accuracy improvement on ImageNet. Specifically, our 2-bit DeiT-T/DeiT-S algorithms outperform the previous state-of-the-art by 9.8% and 7.7%, respectively. Code and models are available at: https://github.com/nbasyl/OFQ.
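
The oscillation described in the abstract can be reproduced with a toy example. The sketch below is not taken from the paper or its repository. It assumes a single latent weight, a fixed-step round-to-nearest quantizer, a straight-through estimator for the gradient, and a squared-error loss whose optimum lies between two quantized levels. The names `quantize` and `simulate_oscillation` and all parameter values are hypothetical, chosen only for illustration.

```python
import math


def quantize(w, scale=1.0):
    """Round-to-nearest uniform quantizer with a fixed step size (assumed setup)."""
    return scale * math.floor(w / scale + 0.5)


def simulate_oscillation(w=0.45, target=0.3, lr=0.2, steps=50):
    """Train one latent weight with a straight-through estimator (STE).

    The loss is 0.5 * (quantize(w) - target) ** 2. Because `target` sits
    between the quantized levels 0 and 1, neither level gives zero gradient,
    so the latent weight keeps being pushed across the rounding threshold.
    Returns the sequence of quantized values and the number of level flips.
    """
    levels = []
    flips = 0
    for _ in range(steps):
        q = quantize(w)
        if levels and q != levels[-1]:
            flips += 1  # the quantized weight jumped to the other level
        levels.append(q)
        grad = q - target  # STE: treat d quantize(w) / d w as 1
        w -= lr * grad
    return levels, flips


if __name__ == "__main__":
    levels, flips = simulate_oscillation()
    print("first quantized values:", levels[:12])
    print("number of flips:", flips)
```

In this toy setting the latent weight hovers around the rounding threshold at 0.5 and the quantized value never settles, which is the behaviour that the paper's techniques are meant to suppress in real networks.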

Code Implementations: 1 repo
