The Anatomy of a Triton Attention Kernel
A long-standing goal in both industry and academia is an LLM inference platform that is portable across hardware architectures, eliminates the need for low-level hand-tuning, and still delivers best-in-class efficiency. In this work, we demonstrate that portable, efficient cross-platform LLM inference is indeed possible, and we share our experience. We develop a paged attention kernel, the core performance-critical component of many LLM deployments. It is written exclusively in Triton, a domain-specific, just-in-time compiled language, and it achieves state-of-the-art performance on both NVIDIA and AMD GPUs. We describe our high-level approach, the key algorithmic and system-level improvements, the parameter auto-tuning required to unlock efficiency, and the integrations into a popular inference server. Together, these are necessary to raise the performance of a generic Triton attention kernel from 19.7% of the state of the art to 105.9%. Our results highlight how open-source domain-specific languages can be leveraged to unlock model portability across GPU vendors.
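To make concrete what a paged attention kernel computes, here is a plain-Python reference for one attention head during decoding (a single query token). It assumes the block-table layout commonly used for paged KV caches. Keys and values live in fixed-size physical blocks, and each sequence owns a table mapping logical block indices to physical block ids, so a sequence's tokens need not be contiguous in memory. The function and parameter names (`paged_attention_decode`, `key_cache`, `value_cache`, `block_table`) are hypothetical and do not come from the paper. This is a correctness sketch, not a Triton kernel. It says nothing about how the authors tile, parallelize, or auto-tune theirs.

```python
import math
from typing import List

Vector = List[float]


def paged_attention_decode(
    query: Vector,
    key_cache: List[List[Vector]],    # [num_physical_blocks][block_size][head_dim]
    value_cache: List[List[Vector]],  # same layout as key_cache
    block_table: List[int],           # logical block index -> physical block id
    context_len: int,                 # number of valid tokens in this sequence
    block_size: int,                  # tokens per physical block
) -> Vector:
    """Hypothetical single-head decode attention over a paged KV cache."""
    if context_len <= 0:
        raise ValueError("context_len must be positive")
    head_dim = len(query)
    scale = 1.0 / math.sqrt(head_dim)

    # Online (streaming) softmax state: running max, running denominator,
    # and running unnormalized weighted sum of values.
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * head_dim

    for token in range(context_len):
        # Translate the logical token position into (physical block, offset).
        physical_block = block_table[token // block_size]
        offset = token % block_size
        key = key_cache[physical_block][offset]
        value = value_cache[physical_block][offset]

        score = scale * sum(q * k for q, k in zip(query, key))
        new_max = max(running_max, score)
        # Rescale earlier contributions to the new max (exp(-inf) == 0.0 on the first token).
        correction = math.exp(running_max - new_max)
        weight = math.exp(score - new_max)
        denom = denom * correction + weight
        acc = [a * correction + weight * v for a, v in zip(acc, value)]
        running_max = new_max

    return [a / denom for a in acc]


if __name__ == "__main__":
    # Two physical blocks of 2 slots each; the sequence uses block 1, then block 0.
    # Slot [0][1] is unused padding and must not affect the result.
    key_cache = [[[0.0, 1.0], [9.0, 9.0]], [[1.0, 0.0], [0.5, 0.5]]]
    value_cache = [[[3.0, 0.0], [100.0, 100.0]], [[1.0, 0.0], [0.0, 2.0]]]
    # A zero query gives uniform weights, so the output is the mean of the
    # three logical values [1, 0], [0, 2], [3, 0], about [1.333, 0.667].
    print(paged_attention_decode([0.0, 0.0], key_cache, value_cache, [1, 0], 3, 2))
```

The running-max rescaling lets a kernel stream over KV blocks one at a time without materializing every score first. A real GPU kernel would process whole blocks, and many heads and sequences, in parallel rather than looping token by token.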