Accelerating a Triton Fused Kernel for W4A16 Quantized Inference with SplitK work decomposition
This work addresses the need for faster inference in large language models. Its contribution is incremental: it optimizes an existing method for a specific bottleneck.
The paper tackles the problem of accelerating W4A16 quantized inference for foundation models by developing a fused kernel that combines dequantization and GEMM using SplitK work decomposition, achieving average speed improvements of 65% on A100 and 124% on H100 for skinny matrix multiplications.
We propose an efficient fused matrix multiplication kernel for W4A16 quantized inference that performs dequantization and GEMM in a single kernel using a SplitK work decomposition. Our implementation improves performance on the skinny matrix-matrix multiplications found in foundation model inference workloads. In particular, this paper focuses on multiplication between a skinny activation matrix and a square weight matrix. Our results show an average speed improvement of 65% on A100 and of 124% on H100 (with a peak of 295%) across a range of matrix dimensions, including those found in a Llama-style model, where m < n = k.
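To make the decomposition concrete, here is a minimal pure-Python sketch of a SplitK W4A16 matrix multiply. It is not the paper's Triton kernel, and it assumes a GPTQ-style asymmetric scheme that the abstract does not specify: eight 4-bit weights packed into each 32-bit word along K, plus a per-group scale and zero point for each column. All names (`pack_weights`, `splitk_w4a16_matmul`, `reference_matmul`, `group_size`, `split_k`) are hypothetical. The activation `a` is assumed to be m × k and the weight k × n, with m small (the skinny case). Each iteration of the outer `pid_k` loop stands in for one SplitK program: it dequantizes only its slice of K and produces a partial m × n result, and the partials are then summed.

```python
# Hypothetical pure-Python reference for a SplitK W4A16 GEMM; not the paper's Triton kernel.
# Assumed GPTQ-style layout: 8 x 4-bit weights per 32-bit word along K,
# with a per-group (along K) scale and zero point for each output column.


def pack_weights(q, k, n):
    """Pack a k x n matrix of 4-bit ints into a (k // 8) x n matrix of 32-bit words."""
    assert k % 8 == 0
    packed = [[0] * n for _ in range(k // 8)]
    for kk in range(k):
        for j in range(n):
            assert 0 <= q[kk][j] < 16
            packed[kk // 8][j] |= q[kk][j] << (4 * (kk % 8))  # low nibble = lowest kk
    return packed


def splitk_w4a16_matmul(a, packed, scales, zeros, k, n, group_size, split_k):
    """Compute a @ dequant(W), splitting the K dimension into split_k slices."""
    m = len(a)
    assert k % split_k == 0
    chunk = k // split_k
    c = [[0.0] * n for _ in range(m)]
    for pid_k in range(split_k):  # one iteration ~ one independent SplitK program
        partial = [[0.0] * n for _ in range(m)]
        for kk in range(pid_k * chunk, (pid_k + 1) * chunk):
            g = kk // group_size
            for j in range(n):
                # Fused dequantization: unpack the nibble and rescale it on the fly,
                # so the full-precision weight matrix is never materialized.
                qv = (packed[kk // 8][j] >> (4 * (kk % 8))) & 0xF
                w = (qv - zeros[g][j]) * scales[g][j]
                for i in range(m):
                    partial[i][j] += a[i][kk] * w
        # Reduce this slice's partial result into the output; a GPU kernel
        # would typically do this with atomic adds rather than a sequential loop.
        for i in range(m):
            for j in range(n):
                c[i][j] += partial[i][j]
    return c


def reference_matmul(a, q, scales, zeros, group_size):
    """Unfused baseline: dequantize every weight, then do a plain matmul."""
    k, n = len(q), len(q[0])
    return [
        [
            sum(a[i][kk] * (q[kk][j] - zeros[kk // group_size][j]) * scales[kk // group_size][j]
                for kk in range(k))
            for j in range(n)
        ]
        for i in range(len(a))
    ]


if __name__ == "__main__":
    m, k, n, group_size = 2, 16, 3, 8  # toy sizes; real workloads have m < n = k
    a = [[float((i + kk) % 5 - 2) for kk in range(k)] for i in range(m)]
    q = [[(3 * kk + j) % 16 for j in range(n)] for kk in range(k)]
    scales = [[0.5, 0.25, 1.0] for _ in range(k // group_size)]
    zeros = [[8, 8, 8] for _ in range(k // group_size)]
    packed = pack_weights(q, k, n)
    ref = reference_matmul(a, q, scales, zeros, group_size)
    for split_k in (1, 2, 4):
        out = splitk_w4a16_matmul(a, packed, scales, zeros, k, n, group_size, split_k)
        assert all(abs(x - y) < 1e-9 for ro, rr in zip(out, ref) for x, y in zip(ro, rr))
    print("SplitK results match the unfused reference")
```

The design choice this illustrates: when m is small, the m × n output yields few tiles, so splitting K creates more independent programs to occupy the GPU. The price is an extra reduction over the partial results, which a GPU implementation commonly performs with atomic adds into the output.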