RadixMLP -- Intra-batch Deduplication for Causal Transformers
This work targets serving efficiency for causal transformer models. It offers a practical optimization for batch inference workloads in which many sequences share a common prefix.
The paper tackles redundant computation of MLP activations for shared prefixes in causal transformer batch inference. It introduces RadixMLP to eliminate this redundancy, achieving speedups of 1.44-1.59x on realistic reranking workloads and up to 5x on synthetic benchmarks.
Batch inference workloads for causal transformer models frequently process sequences that share common prefixes, such as system prompts, few-shot examples, or shared queries. Standard inference engines treat each sequence independently, redundantly recomputing identical MLP activations for every copy of the shared prefix. We introduce RadixMLP, a technique that exploits the position-wise nature of MLPs, LayerNorms, linear projections, and embeddings to eliminate this redundancy. RadixMLP dynamically maps batches to a prefix trie, gathering shared segments into a compressed representation for position-wise computation and scattering results back only at attention boundaries. RadixMLP is stateless and operates within a single forward pass. In end-to-end serving benchmarks on MS MARCO v1.1 with Qwen3 models (0.6B to 8B parameters), RadixMLP achieves 1.44-1.59x speedups in realistic reranking workloads, with up to 5x speedups on synthetic benchmarks with longer shared prefixes. Our code is available at https://github.com/michaelfeil/radix-mlp.
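The trie-based gather/scatter idea in the abstract can be illustrated with a minimal plain-Python sketch. This is not the code from the radix-mlp repository. The names `build_radix_index` and `radix_positionwise` are hypothetical, and the toy function stands in for a position-wise op such as an MLP or LayerNorm. The sketch assumes that, in a causal model, a token's input to a position-wise op is fully determined by the prefix ending at that token (with every sequence starting at position 0). Under that assumption, each trie node needs to be computed only once.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def build_radix_index(
    batch: Sequence[Sequence[int]],
) -> Tuple[List[int], List[int], List[int]]:
    """Map every token in the batch onto a node of a prefix trie (hypothetical helper).

    Returns (compact_tokens, compact_positions, scatter_index): row k of the
    compact arrays is one unique prefix, and scatter_index[j] is the compact
    row for the j-th token of the flattened (concatenated) batch.
    """
    children: List[Dict[int, int]] = [{}]  # node 0 is the virtual root
    compact_tokens: List[int] = []
    compact_positions: List[int] = []
    scatter_index: List[int] = []
    for seq in batch:
        node = 0
        for pos, tok in enumerate(seq):
            nxt = children[node].get(tok)
            if nxt is None:  # first time this prefix is seen: allocate a compact row
                nxt = len(children)
                children.append({})
                children[node][tok] = nxt
                compact_tokens.append(tok)
                compact_positions.append(pos)
            scatter_index.append(nxt - 1)  # trie node n <-> compact row n - 1
            node = nxt
    return compact_tokens, compact_positions, scatter_index


def radix_positionwise(
    batch: Sequence[Sequence[int]],
    fn: Callable[[int, int], int],
) -> Tuple[List[List[int]], int]:
    """Apply a position-wise fn once per unique prefix, then scatter back (hypothetical)."""
    tokens, positions, scatter_index = build_radix_index(batch)
    # Position-wise compute on the compressed representation only.
    compact_out = [fn(t, p) for t, p in zip(tokens, positions)]
    # Scatter: expand compact rows back to the flattened batch layout.
    flat_out = [compact_out[i] for i in scatter_index]
    # Split the flat layout back into per-sequence lists.
    result: List[List[int]] = []
    start = 0
    for seq in batch:
        result.append(flat_out[start : start + len(seq)])
        start += len(seq)
    return result, len(tokens)


if __name__ == "__main__":
    def toy_fn(tok: int, pos: int) -> int:
        # Stand-in for an MLP applied to one position.
        return tok * 10 + pos

    batch = [[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6]]
    out, n_compact = radix_positionwise(batch, toy_fn)
    print(out, n_compact)  # [[10, 21, 32, 43], [10, 21, 32, 53], [10, 21, 62]] 6
```

Here 11 batch tokens collapse to 6 compact rows. Keying on the full trie path rather than on (token, position) alone matters, because the same token at the same position can follow different contexts and then carries a different hidden state. For B sequences that share a prefix of length P and each have a distinct suffix of length S, the compact representation has P + B*S rows instead of B*(P + S). This bounds the savings for the position-wise ops. In the method as described, results are scattered back at each attention boundary, and attention still runs in the per-sequence layout. This sketch scatters only once at the end.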