FlashFormer: Whole-Model Kernels for Efficient Low-Batch Inference
This work addresses the low-batch inference bottleneck for transformer-based LLMs, which matters most for edge deployment and latency-sensitive applications. The contribution is incremental in that it builds on existing kernel optimization approaches.
The paper tackles inefficient low-batch inference for large language models, a setting that is critical for edge deployment and latency-sensitive applications. It introduces FlashFormer, a specialized kernel that achieves nontrivial speedups over existing state-of-the-art inference kernels across various model sizes and quantization settings.
The size and compute characteristics of modern large language models have led to increased interest in developing specialized kernels tailored for training and inference. Existing kernels primarily optimize for compute utilization, targeting the large-batch training and inference settings. However, low-batch inference, where memory bandwidth and kernel launch overheads are significant factors, remains important for many applications of interest, such as edge deployment and latency-sensitive applications. This paper describes FlashFormer, a proof-of-concept kernel for accelerating single-batch inference for transformer-based large language models. Across various model sizes and quantization settings, we observe nontrivial speedups compared to existing state-of-the-art inference kernels.
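To make concrete why memory bandwidth and kernel launch overheads matter at batch size one, the following is a minimal back-of-the-envelope sketch, not a model from the paper. It assumes that each decoded token requires reading all weights once from memory and issuing a fixed number of kernel launches. The function name and every numeric value (parameter count, bandwidth, launch count, per-launch overhead) are hypothetical and chosen only for illustration.

```python
from typing import NamedTuple


class LatencyEstimate(NamedTuple):
    memory_s: float      # time to stream all weights once
    launch_s: float      # accumulated kernel launch overhead
    total_s: float       # simple sum of the two terms
    launch_share: float  # fraction of the total spent on launches


def estimate_decode_latency(
    n_params: float,
    bytes_per_param: float,
    bandwidth_bytes_per_s: float,
    n_kernel_launches: int,
    launch_overhead_s: float,
) -> LatencyEstimate:
    """Rough per-token latency estimate for batch-size-1 decoding.

    Assumes (hypothetically) that the weights are read exactly once per
    token and that the two cost terms do not overlap.
    """
    memory_s = n_params * bytes_per_param / bandwidth_bytes_per_s
    launch_s = n_kernel_launches * launch_overhead_s
    total_s = memory_s + launch_s
    return LatencyEstimate(memory_s, launch_s, total_s, launch_s / total_s)


# Hypothetical setup: 8e9 parameters, 1e12 bytes/s of memory bandwidth,
# 500 kernel launches per token at 5 microseconds each.
fp16 = estimate_decode_latency(8e9, 2.0, 1e12, 500, 5e-6)
int4 = estimate_decode_latency(8e9, 0.5, 1e12, 500, 5e-6)

print(f"16-bit: {fp16.total_s * 1e3:.2f} ms/token, launch share {fp16.launch_share:.0%}")
print(f" 4-bit: {int4.total_s * 1e3:.2f} ms/token, launch share {int4.launch_share:.0%}")
```

Under these assumed numbers, quantizing the weights shrinks the memory term but leaves the launch term unchanged, so launch overhead grows from roughly 14% to roughly 38% of the estimated per-token time. This is one way to see why reducing the number of separate kernel launches can be attractive in the low-batch regime.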