Prompt-based Depth Pruning of Large Language Models
This work addresses inference efficiency for users of large language models, but it is incremental: it builds on existing depth pruning techniques and adds a task-aware adaptation.
The paper tackles the problem of reducing inference cost in large language models by removing transformer blocks, and finds that block importance is task-dependent. It introduces PuDDing, a dynamic depth pruning algorithm that uses a lightweight router to omit blocks based on the input prompt, achieving better performance than static methods on commonsense reasoning benchmarks.
Depth pruning aims to reduce the inference cost of a large language model without any hardware-specific complications, by simply removing several less important transformer blocks. However, our empirical findings suggest that the importance of a transformer block may be highly task-dependent -- a block that is crucial for a task can be removed without degrading the accuracy on another task. Based on this observation, we develop a dynamic depth pruning algorithm, coined PuDDing (Prompt-routed Dynamic Depth Pruning), which determines which blocks to omit from the model based on the input prompt. PuDDing operates by training a lightweight router to predict the best omission set among a set of options, where this option set has also been constructed in a data-driven manner. Empirical results on commonsense reasoning benchmarks demonstrate that PuDDing effectively accelerates the inference of language models, and achieves better on-task performance than static depth pruning baselines.
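The routing idea described above can be illustrated with a minimal toy sketch. This is not the authors' implementation: the abstract does not specify the router architecture, the prompt representation, or how the candidate omission sets are built, so everything here is an assumption for illustration. The blocks are stand-in functions rather than transformer layers, the router is a hypothetical linear scorer over a hand-written prompt feature vector, and the names `make_blocks`, `route`, and `pruned_forward` are invented for this example.

```python
from typing import Callable, List, Sequence, Set

Block = Callable[[float], float]


def make_blocks(n: int) -> List[Block]:
    # Toy stand-ins for transformer blocks (assumption): block i adds i + 1
    # to a scalar "hidden state".
    return [lambda h, i=i: h + (i + 1) for i in range(n)]


def route(prompt_features: Sequence[float],
          router_weights: Sequence[Sequence[float]]) -> int:
    # Hypothetical linear router: one score per candidate omission set,
    # returning the index of the highest-scoring candidate.
    scores = [sum(w * x for w, x in zip(row, prompt_features))
              for row in router_weights]
    return max(range(len(scores)), key=scores.__getitem__)


def pruned_forward(h: float, blocks: Sequence[Block], omit: Set[int]) -> float:
    # Run the blocks in order, skipping every block whose index is omitted.
    for idx, block in enumerate(blocks):
        if idx in omit:
            continue
        h = block(h)
    return h


blocks = make_blocks(4)
# Candidate omission sets; in the paper these are constructed from data,
# here they are fixed by hand (assumption).
candidates: List[Set[int]] = [{1}, {2, 3}]
# One weight row per candidate; a trained router would learn these (assumption).
router_weights = [[1.0, 0.0], [0.0, 1.0]]

choice = route([0.2, 0.9], router_weights)
output = pruned_forward(0.0, blocks, candidates[choice])
```

The router runs once per prompt, and its only output is the index of an omission set, so the blocks in that set are never executed for that prompt. A static depth pruning baseline corresponds to always using the same omission set regardless of the prompt.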