CL · Jul 2, 2021

Learned Token Pruning for Transformers

arXiv:2107.00910v3 · 218 citations · Has Code
Originality: Incremental advance
AI Analysis

This work addresses the practical problem of deploying transformers efficiently by reducing their computational cost, although it is an incremental improvement over existing token pruning methods.

The paper tackles the high inference cost of transformer models by introducing Learned Token Pruning (LTP), which adaptively removes unimportant tokens as the input passes through the layers. LTP achieves up to 2.1x FLOPs reduction with less than 1% accuracy drop, and up to ~2.5% higher accuracy than prior token pruning methods at the same FLOPs.

Deploying transformer models in practice is challenging due to their inference cost, which scales quadratically with input sequence length. To address this, we present a novel Learned Token Pruning (LTP) method which adaptively removes unimportant tokens as an input sequence passes through transformer layers. In particular, LTP prunes tokens with an attention score below a threshold value which is learned for each layer during training. Our threshold-based method allows the length of the pruned sequence to vary adaptively based on the input sequence, and avoids algorithmically expensive operations such as top-k token selection. We extensively test the performance of LTP on GLUE tasks and show that our method outperforms the prior state-of-the-art token pruning methods by up to ~2.5% higher accuracy with the same amount of FLOPs. In particular, LTP achieves up to 2.1x FLOPs reduction with less than 1% accuracy drop, which results in up to 1.9x and 2.0x throughput improvement on Intel Haswell CPUs and NVIDIA V100 GPUs, respectively. Furthermore, we demonstrate that LTP is more robust than prior methods to variations in input sentence length. Our code has been developed in PyTorch and has been open-sourced.
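The threshold rule described in the abstract can be illustrated with a minimal NumPy sketch. It assumes (this is not stated above) that a token's attention score is the attention it receives, averaged over heads and query positions, and that the per-layer threshold is learned through a sigmoid relaxation of the hard cut-off. The function names `token_importance`, `hard_prune`, and `soft_mask` and the temperature value are hypothetical; this is not the authors' released PyTorch code.

```python
import numpy as np


def token_importance(attn: np.ndarray) -> np.ndarray:
    """Assumed score: average attention each token receives.

    attn has shape (num_heads, seq_len, seq_len), where attn[h, j, i] is the
    softmax attention that query token j pays to key token i in head h.
    Returns a (seq_len,) array. Because every attention row sums to 1,
    these scores also sum to 1.
    """
    return attn.mean(axis=(0, 1))


def hard_prune(hidden: np.ndarray, attn: np.ndarray, threshold: float):
    """Inference-time pruning: drop tokens whose score is below `threshold`.

    hidden has shape (seq_len, d_model). Only an elementwise comparison is
    needed (no sorting or top-k selection), and the number of surviving
    tokens depends on the input sequence.
    """
    keep = token_importance(attn) >= threshold
    return hidden[keep], keep


def soft_mask(scores: np.ndarray, threshold: float,
              temperature: float = 0.01) -> np.ndarray:
    """Hypothetical differentiable relaxation of the hard threshold.

    sigmoid((score - threshold) / temperature) is close to 1 above the
    threshold and close to 0 below it. In an autodiff framework, scaling
    hidden states by this mask would let gradients reach `threshold`.
    """
    return 1.0 / (1.0 + np.exp(-(scores - threshold) / temperature))


# Usage: one head, four tokens; tokens 0 and 1 receive most of the attention.
attn = np.array([[
    [0.4, 0.4, 0.1, 0.1],
    [0.5, 0.3, 0.1, 0.1],
    [0.3, 0.5, 0.1, 0.1],
    [0.4, 0.4, 0.1, 0.1],
]])
hidden = np.arange(8, dtype=float).reshape(4, 2)

pruned, keep = hard_prune(hidden, attn, threshold=0.2)
print(keep)          # [ True  True False False]
print(pruned.shape)  # (2, 2)
print(soft_mask(token_importance(attn), threshold=0.2).round(3))
```

Because each sequence can keep a different number of tokens, batched inference would need padding or masking on top of this single-sequence sketch.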

Code Implementations: 1 repo