One Jump Is All You Need: Short-Cutting Transformers for Early Exit Prediction with One Jump to Fit All Exit Levels
This work improves inference efficiency for transformer models, but the contribution is incremental: it builds on existing low-rank early-exit methods.
The paper tackles the cost of inference in large language models by proposing a single low-rank shortcut for early-exit prediction from any transformer block. This yields over a 30x reduction in shortcut parameter costs during inference while largely matching the performance of maintaining a separate shortcut per block on GPT2-XL, Phi3-Mini, and Llama2-7B.
To reduce the time and computational costs of inference of large language models, there has been interest in parameter-efficient low-rank early-exit casting of transformer hidden-representations to final-representations. Such low-rank short-cutting has been shown to outperform identity shortcuts at early model stages while offering parameter-efficiency in shortcut jumps. However, current low-rank methods maintain a separate early-exit shortcut jump to final-representations for each transformer intermediate block-level during inference. In this work, we propose selection of a single One-Jump-Fits-All (OJFA) low-rank shortcut that offers over a 30x reduction in shortcut parameter costs during inference. We show that despite this extreme reduction, our OJFA choice largely matches the performance of maintaining multiple shortcut jumps during inference and offers stable precision from all transformer block-levels for GPT2-XL, Phi3-Mini and Llama2-7B transformer models.
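As a rough illustration of the idea in the abstract, the following Python sketch applies one low-rank linear map to hidden states from every block and compares the stored shortcut parameters against keeping one map per block. The function names, the factorization into two matrices `A` and `B`, the random weights, and the sizes are all assumptions made for illustration; the paper's actual shortcut parameterization, training, and selection procedure are not described here.

```python
import numpy as np


def make_low_rank_shortcut(d_model, rank, rng):
    """Return factors (A, B) of a rank-`rank` linear map d_model -> d_model.

    Random weights stand in for a trained shortcut (assumption).
    """
    A = rng.standard_normal((d_model, rank)) / np.sqrt(d_model)
    B = rng.standard_normal((rank, d_model)) / np.sqrt(rank)
    return A, B


def apply_shortcut(hidden, A, B):
    """Cast hidden states of shape (n_tokens, d_model) to approximate final ones."""
    return (hidden @ A) @ B


def shortcut_param_count(d_model, rank, n_shortcuts):
    """Parameters stored for n_shortcuts low-rank jumps (two factors each)."""
    return n_shortcuts * 2 * d_model * rank


rng = np.random.default_rng(0)
d_model, rank, n_blocks = 64, 8, 32  # illustrative sizes, not taken from the paper

# One shared shortcut is reused for the hidden states of every block.
A, B = make_low_rank_shortcut(d_model, rank, rng)
hidden_by_block = [rng.standard_normal((5, d_model)) for _ in range(n_blocks)]
predicted_final = [apply_shortcut(h, A, B) for h in hidden_by_block]

# Storage cost: one shortcut per block versus a single shared shortcut.
per_block_params = shortcut_param_count(d_model, rank, n_blocks)
single_params = shortcut_param_count(d_model, rank, 1)
reduction = per_block_params / single_params
```

Under these assumptions the saving equals the number of blocks, so a model with more than 30 blocks would see a reduction above 30x. Whether the reported figure arises in exactly this way is not stated in the abstract.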