DC · AI · Jan 5, 2024

Accelerating a Triton Fused Kernel for W4A16 Quantized Inference with SplitK work decomposition

arXiv:2402.00025v2 · 3 citations · h-index: 36
Originality: Incremental advance
AI Analysis

This work addresses the need for faster inference in large language models, but it is incremental as it optimizes an existing method for a specific bottleneck.

The paper tackles the problem of accelerating W4A16 quantized inference for foundation models by developing a fused kernel that combines dequantization and GEMM using SplitK work decomposition, achieving average speed improvements of 65% on A100 and 124% on H100 for skinny matrix multiplications.

We propose an implementation of an efficient fused matrix multiplication kernel for W4A16 quantized inference, where we perform dequantization and GEMM in a fused kernel using a SplitK work decomposition. Our implementation shows improvement for the type of skinny matrix-matrix multiplications found in foundation model inference workloads. In particular, this paper surveys the type of matrix multiplication between a skinny activation matrix and a square weight matrix. Our results show an average of 65% speed improvement on A100, and an average of 124% speed improvement on H100 (with a peak of 295%) for a range of matrix dimensions including those found in a llama-style model, where m < n = k.
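The idea in the abstract can be sketched on the CPU with NumPy. This is a minimal illustration under stated assumptions, not the paper's Triton kernel: the names `pack_int4`, `unpack_int4`, and `splitk_w4a16_matmul` are hypothetical, the weights are assumed to be packed two 4-bit values per byte along the k dimension, and a single scale and zero point per output column is assumed (real quantization schemes often use group-wise parameters). The loop over `split_k` chunks stands in for the parallel thread blocks that, in a GPU SplitK kernel, would each reduce one slice of k and combine their partial results in the output.

```python
import numpy as np

def pack_int4(w_q):
    """Pack 4-bit values (0..15) of shape (k, n) into bytes, two per byte along k.
    The packing layout is an assumption made for this sketch."""
    assert w_q.shape[0] % 2 == 0
    lo = w_q[0::2].astype(np.uint8)   # even rows go to the low nibble
    hi = w_q[1::2].astype(np.uint8)   # odd rows go to the high nibble
    return lo | (hi << 4)             # shape (k // 2, n)

def unpack_int4(packed):
    """Inverse of pack_int4: recover a (k, n) array of values in 0..15."""
    k2, n = packed.shape
    out = np.empty((2 * k2, n), dtype=np.uint8)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out

def splitk_w4a16_matmul(a, packed, scales, zeros, split_k=4):
    """Compute a @ dequant(packed) by splitting the k dimension into split_k chunks.

    a:      (m, k) float16 activations
    packed: (k // 2, n) uint8 packed 4-bit weights
    scales: (n,) float32 per-column scales (assumed layout)
    zeros:  (n,) float32 per-column zero points (assumed layout)
    """
    m, k = a.shape
    n = packed.shape[1]
    assert k % (2 * split_k) == 0
    chunk = k // split_k
    c = np.zeros((m, n), dtype=np.float32)
    for s in range(split_k):
        k0, k1 = s * chunk, (s + 1) * chunk
        # Dequantize only this chunk's weights, right before they are used,
        # so the full-precision weight matrix is never materialized.
        w_q = unpack_int4(packed[k0 // 2 : k1 // 2])
        w = (w_q.astype(np.float32) - zeros) * scales
        # Accumulate the partial product; on a GPU each chunk would be a
        # separate work unit adding into the same output tile.
        c += a[:, k0:k1].astype(np.float32) @ w
    return c
```

For a skinny problem (small m, large n = k), an ordinary tiling over the output yields few independent work units; splitting along k creates `split_k` times more, which is the motivation for the decomposition in this setting.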
