LG · Aug 21, 2024

MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models

arXiv:2408.11743v1 · 46 citations · h-index: 41
Originality: Incremental advance
AI Analysis

This work addresses the efficiency of LLM serving in practical multi-user settings. It is an incremental improvement that optimizes existing weight-quantization techniques for batched workloads.

The paper tackles the problem of obtaining speedups from weight quantization in batched inference for large language models (LLMs) serving multiple parallel clients. It shows that, for 4-bit weights, MARLIN kernels sustain close to the maximum 4x speedup at batch sizes up to 16-32, and gradually decreasing but still significant acceleration at batch sizes up to 64-128. Integrated with the vLLM serving engine, this yields end-to-end inference speedups of up to 2.8x.
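To illustrate why the quantization speedup shrinks as the batch grows, the following back-of-the-envelope roofline model compares the estimated time of one linear layer with FP16 weights against the same layer with 4-bit weights. It is a sketch under stated assumptions, not the paper's performance model: the function names `layer_time` and `quant_speedup`, the 8192 x 8192 layer shape, and the GPU figures (300 TFLOP/s peak, 2 TB/s bandwidth) are hypothetical placeholders, and the model ignores dequantization cost, launch overhead, and imperfect overlap of compute and memory traffic.

```python
def layer_time(batch: int, k: int, n: int, weight_bits: int,
               peak_flops: float, mem_bw: float) -> float:
    """Roofline estimate (seconds) for Y = X @ W, with X: batch x k, W: k x n.

    Assumes FP16 activations and outputs, perfect overlap of compute and
    memory traffic, and no dequantization or kernel-launch overhead.
    """
    flops = 2 * batch * k * n                 # one multiply-add per weight per row of X
    weight_bytes = k * n * weight_bits / 8    # weights are streamed once per call
    act_bytes = 2 * (batch * k + batch * n)   # FP16 input read + FP16 output write
    compute_time = flops / peak_flops
    memory_time = (weight_bytes + act_bytes) / mem_bw
    return max(compute_time, memory_time)     # the slower resource bounds the layer


def quant_speedup(batch: int, k: int = 8192, n: int = 8192,
                  peak_flops: float = 300e12, mem_bw: float = 2e12) -> float:
    """Estimated FP16-weight time divided by 4-bit-weight time.

    The default GPU figures are hypothetical placeholders, not measurements.
    """
    fp16 = layer_time(batch, k, n, 16, peak_flops, mem_bw)
    int4 = layer_time(batch, k, n, 4, peak_flops, mem_bw)
    return fp16 / int4


if __name__ == "__main__":
    for b in (1, 16, 32, 64, 128, 256):
        print(f"batch={b:4d}  estimated speedup={quant_speedup(b):.2f}x")
```

Under these placeholder numbers the estimate stays near 4x while the 4-bit layer is memory-bound, then falls toward 1x once both variants become compute-bound. Where exactly the drop happens depends on the GPU's ratio of compute throughput to memory bandwidth, and a real kernel must also keep its own overheads small enough to stay close to this idealized bound.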

As inference on Large Language Models (LLMs) emerges as an important workload in machine learning applications, weight quantization has become a standard technique for efficient GPU deployment. Quantization not only reduces model size, but has also been shown to yield substantial speedups for single-user inference, due to reduced memory movement, with low accuracy impact. Yet it remains open whether speedups are also achievable in batched settings with multiple parallel clients, which are highly relevant for practical serving. It is unclear whether GPU kernels can be designed to remain practically memory-bound while supporting the substantially increased compute requirements of batched workloads. This paper resolves this question positively by describing the design of Mixed-precision Auto-Regressive LINear kernels, called MARLIN. Concretely, given a model whose weights are compressed via quantization to, e.g., 4 bits per element, MARLIN shows that batch sizes up to 16-32 can be supported with close to maximum (4×) quantization speedup, and larger batch sizes up to 64-128 with gradually decreasing, but still significant, acceleration. MARLIN accomplishes this via a combination of techniques, such as asynchronous memory access, complex task scheduling and pipelining, and bespoke quantization support. Our experiments show that MARLIN's near-optimal performance on individual LLM layers across different scenarios can also lead to end-to-end LLM inference speedups (of up to 2.8×) when integrated with the popular vLLM serving engine. Finally, MARLIN is extensible to further compression techniques, like NVIDIA 2:4 sparsity, leading to additional speedups.
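The abstract's 4× ceiling comes from storing each weight in 4 bits instead of 16. The sketch below shows one common way to do this: symmetric per-group quantization to 4-bit codes with a zero point of 8, eight codes packed per 32-bit word, and on-the-fly dequantization. It is a generic illustration, not MARLIN's actual weight layout or quantization format (the paper's "bespoke quantization support" is not described in this summary); the helper names `quantize_group`, `pack_int4`, `unpack_int4`, and `dequantize` are hypothetical.

```python
from typing import List, Tuple


def quantize_group(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric 4-bit quantization of one weight group.

    Codes lie in [0, 15]; code 8 represents zero, so values map to -7..7 steps.
    """
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    codes = [min(15, max(0, round(w / scale) + 8)) for w in weights]
    return codes, scale


def pack_int4(codes: List[int]) -> List[int]:
    """Pack eight 4-bit codes into each 32-bit word, lowest nibble first."""
    if len(codes) % 8:
        raise ValueError("number of codes must be a multiple of 8")
    words = []
    for start in range(0, len(codes), 8):
        word = 0
        for i, code in enumerate(codes[start:start + 8]):
            word |= (code & 0xF) << (4 * i)
        words.append(word)
    return words


def unpack_int4(words: List[int]) -> List[int]:
    """Inverse of pack_int4: extract eight nibbles from each word."""
    return [(word >> (4 * i)) & 0xF for word in words for i in range(8)]


def dequantize(codes: List[int], scale: float) -> List[float]:
    """Map codes back to approximate FP values: (code - 8) * scale."""
    return [(code - 8) * scale for code in codes]


if __name__ == "__main__":
    w = [0.7, -0.35, 0.0, 0.1, -0.7, 0.2, 0.35, -0.05]
    codes, scale = quantize_group(w)
    words = pack_int4(codes)
    print(hex(words[0]), dequantize(unpack_int4(words), scale))
```

Eight weights occupy 4 bytes here instead of 16 bytes in FP16, which is where the 4× reduction in weight traffic comes from; the per-group scale adds a small overhead. A GPU kernel performs the unpacking and scaling in registers right before the matrix multiply, so the savings only translate into speedup while the layer stays memory-bound.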
