High-Layer Attention Pruning with Rescaling
This work addresses the need for efficient compression of large language models to reduce inference latency, offering an incremental improvement over conventional pruning techniques.
The paper proposes a structured pruning method for large language models that removes attention heads in higher layers and applies adaptive rescaling to preserve the scale of token representations. Across 27 datasets and several LLMs, the method consistently outperforms existing structured pruning methods.
Pruning is a highly effective approach for compressing large language models (LLMs) and significantly reduces inference latency. However, conventional training-free structured pruning methods typically rely on a heuristic metric that removes attention heads indiscriminately across all pruned layers, without considering where those layers sit within the network architecture. In this work, we propose a novel pruning algorithm that strategically prunes attention heads in the model's higher layers. Because removing attention heads can alter the magnitude of token representations, we introduce an adaptive rescaling parameter that calibrates the representation scale after pruning to counteract this effect. We conduct comprehensive experiments on a wide range of LLMs, including LLaMA3.1-8B, Mistral-7B-v0.3, Qwen2-7B, and Gemma2-9B, evaluating on both generation and discriminative tasks across 27 datasets. The results consistently show that our method outperforms existing structured pruning methods. The advantage is most pronounced on generation tasks, where our approach surpasses existing baselines by a significant margin.
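The two ingredients described above, restricting head pruning to the top layers and restoring the representation scale with a single scalar, can be illustrated with a toy sketch in plain Python. This is not the paper's implementation. The layer-selection rule (`higher_layers`), the choice of which heads to drop (a `keep` mask supplied by the caller), and the calibration rule for the rescaling factor (the ratio of mean attention-output norms before and after pruning on calibration tokens) are all hypothetical assumptions made for illustration.

```python
import math
from typing import List

Vector = List[float]
# One calibration token: the per-head contributions to the attention output.
# In multi-head attention, concat(heads) @ W_O equals the sum over heads of
# head_h @ W_O[h], so each head's contribution can be stored as one vector.
TokenHeads = List[Vector]


def higher_layers(num_layers: int, num_pruned: int) -> List[int]:
    """Hypothetical rule: prune heads only in the top `num_pruned` layers."""
    return list(range(num_layers - num_pruned, num_layers))


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def combine_heads(heads: TokenHeads, keep: List[bool]) -> Vector:
    """Sum the contributions of the heads that survive pruning."""
    out = [0.0] * len(heads[0])
    for contribution, kept in zip(heads, keep):
        if kept:
            out = [o + c for o, c in zip(out, contribution)]
    return out


def rescale_factor(calib: List[TokenHeads], keep: List[bool]) -> float:
    """Assumed calibration rule: one scalar alpha that matches the mean norm
    of the pruned attention output to that of the unpruned output."""
    all_heads = [True] * len(keep)
    full = sum(l2_norm(combine_heads(t, all_heads)) for t in calib) / len(calib)
    pruned = sum(l2_norm(combine_heads(t, keep)) for t in calib) / len(calib)
    if pruned == 0.0:
        raise ValueError("pruned output has zero norm; cannot rescale")
    return full / pruned


def pruned_output(heads: TokenHeads, keep: List[bool], alpha: float) -> Vector:
    """Attention output of a pruned layer, rescaled by alpha."""
    return [alpha * x for x in combine_heads(heads, keep)]


if __name__ == "__main__":
    print("pruned layers:", higher_layers(num_layers=32, num_pruned=8))
    calib = [
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],  # token 1, three heads
        [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]],  # token 2, three heads
    ]
    keep = [True, True, False]  # drop head 2 (hypothetical choice)
    alpha = rescale_factor(calib, keep)
    print("alpha:", alpha)
    print("rescaled output for token 1:", pruned_output(calib[0], keep, alpha))
```

Because the factor in this sketch is a single scalar per layer, it could be folded into that layer's output projection weights once calibrated, so the rescaling itself would add no inference cost. Whether the paper computes or applies its parameter this way is not stated in this summary.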