Compressing Many-Shots in In-Context Learning
This work addresses the memory and compute cost of in-context learning for LLM users. The improvement is domain-specific and incremental, but useful for practical deployment.
The paper reduces the high memory and computational cost of many-shot in-context learning in large language models by compressing the many-shot prompt. Its accuracy typically drops by less than 10%, whereas baselines often degrade by over 20-30% at higher compression ratios.
Large Language Models (LLMs) can learn different tasks without explicit finetuning when given many input-output examples (demonstrations) through In-Context Learning (ICL). Increasing the number of examples, called "shots", improves downstream task performance but incurs higher memory and computational costs. In this work, we study an approach to improve the memory and computational efficiency of ICL inference by compressing the many-shot prompts. Given many shots comprising t tokens, our goal is to generate a summary of m soft tokens, where m < t. We first show that existing prompt compression methods are ineffective for many-shot compression, and that simply using fewer shots is a surprisingly strong baseline. To achieve effective compression, we find that: (a) a stronger compressor model with more trainable parameters is necessary, and (b) compressing many-shot representations at each transformer layer enables more fine-grained compression by providing each layer with its own compressed representation. Based on these insights, we propose MemCom, a layer-wise compression method. We systematically evaluate various compressor models and training approaches across different model sizes (2B and 7B), architectures (Gemma and Mistral), many-shot sequence lengths (3k-6k tokens), and compression ratios (3x to 8x). MemCom outperforms strong baselines across all compression ratios on multiple classification tasks with large label sets. Notably, while baseline performance degrades sharply at higher compression ratios, often by over 20-30%, MemCom maintains high accuracy with minimal degradation, typically dropping by less than 10%.
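To make the idea of layer-wise compression concrete, the following is a minimal sketch, not the MemCom implementation. It assumes a simple cross-attention pooling step in which m query vectors (which would be trainable in practice) summarize the t many-shot hidden states at each layer, so that every layer receives its own m-vector summary. The function names (`compress_layer`, `compress_many_shots`), the toy sizes, and the use of random vectors in place of real hidden states and trained parameters are all illustrative assumptions.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def compress_layer(hidden, queries):
    """Hypothetical compressor: pool t hidden states (t x d) into m vectors (m x d).

    Each query attends over all t token states; the output row is the
    attention-weighted average of those states.
    """
    d = len(hidden[0])
    summary = []
    for q in queries:
        scores = [sum(qi * hi for qi, hi in zip(q, h)) / math.sqrt(d) for h in hidden]
        weights = softmax(scores)
        summary.append([sum(w * h[j] for w, h in zip(weights, hidden)) for j in range(d)])
    return summary


def compress_many_shots(layer_hiddens, layer_queries):
    """Apply a separate compression at every layer (one summary per layer)."""
    return [compress_layer(h, q) for h, q in zip(layer_hiddens, layer_queries)]


def random_matrix(rng, rows, cols):
    """Stand-in for real hidden states or trained query parameters."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
num_layers, t, d, ratio = 4, 48, 8, 8   # toy sizes; 8x is the top of the 3x-8x range
m = t // ratio                           # number of soft tokens per layer

layer_hiddens = [random_matrix(rng, t, d) for _ in range(num_layers)]
layer_queries = [random_matrix(rng, m, d) for _ in range(num_layers)]

summaries = compress_many_shots(layer_hiddens, layer_queries)
print(len(summaries), len(summaries[0]), len(summaries[0][0]))  # 4 6 8
```

In this sketch each layer stores m = 6 vectors instead of t = 48, so the per-layer memory for the shots shrinks by the compression ratio. How the compressor is actually parameterized and trained is described only at a high level in the abstract, so those details are left out here.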