Mar 30, 2024

DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference

arXiv:2404.00242v4 · 13 citations · h-index: 5 · Has Code · ICLR
Originality: Incremental advance
AI Analysis

This work addresses performance bottlenecks in tree-structured LLM inference for applications such as few-shot prompting and speculative decoding, and represents an incremental improvement over existing attention algorithms.

The paper tackles the inefficiency of existing inference systems for tree-structured LLM applications by proposing DeFT, a hardware-efficient attention algorithm that reduces KV cache IO by 73-99% and achieves up to 2.23x end-to-end speedup compared to state-of-the-art methods.
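The IO reduction comes from loading the KV cache of a shared prefix once instead of once per sequence (the abstract calls this KV-Guided Grouping). Below is a toy count of KV tokens read, comparing a per-sequence baseline, where each decoding sequence reads the KV cache along its own root-to-leaf path, with per-node loading, where each node's KV is read once and reused by every query beneath it. The `Node` class, the function names and the tree shape are hypothetical; the model counts only KV loads and ignores query traffic, partial results and the GPU memory hierarchy, so it illustrates the mechanism rather than reproducing the reported 73-99% figure.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class Node:
    """Hypothetical tree node: a run of tokens whose KV cache is stored once."""
    tokens: int
    children: list[Node] = field(default_factory=list)


def all_nodes(node: Node) -> Iterator[Node]:
    """Yield every node in the tree (pre-order)."""
    yield node
    for child in node.children:
        yield from all_nodes(child)


def leaf_paths(node: Node, prefix: tuple[Node, ...] = ()) -> Iterator[tuple[Node, ...]]:
    """Yield the root-to-leaf path of every leaf; each leaf is one decoding sequence."""
    path = prefix + (node,)
    if not node.children:
        yield path
    for child in node.children:
        yield from leaf_paths(child, path)


def per_query_io(root: Node) -> int:
    """KV tokens read when every sequence loads its whole root-to-leaf path on its own."""
    return sum(n.tokens for path in leaf_paths(root) for n in path)


def per_node_io(root: Node) -> int:
    """KV tokens read when each node's KV is loaded once and shared by all queries below it."""
    return sum(n.tokens for n in all_nodes(root))


def io_reduction(root: Node) -> float:
    """Fraction of KV IO saved by per-node loading relative to per-sequence loading."""
    return 1 - per_node_io(root) / per_query_io(root)


# Example: a 4000-token shared prefix with 8 branches of 50 tokens each.
root = Node(4000, [Node(50) for _ in range(8)])
print(per_query_io(root), per_node_io(root), f"{io_reduction(root):.1%}")  # → 32400 4400 86.4%
```

Longer shared prefixes and more branches push the saving toward 100%, while short prefixes shrink it; these toy numbers are not the paper's measurements.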

Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared prefixes of tokens, including few-shot prompting, multi-step reasoning, and speculative decoding. However, existing inference systems for tree-based applications are inefficient due to improper partitioning of queries and KV cache during attention calculation. This leads to two main issues: (1) a lack of memory access (IO) reuse for the KV cache of shared prefixes, and (2) poor load balancing. As a result, there is redundant KV cache IO between GPU global memory and shared memory, along with low GPU utilization. To address these challenges, we propose DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of read/write operations on the KV cache during attention calculation through KV-Guided Grouping, a method that avoids repeatedly loading the KV cache of shared prefixes in attention computation. Additionally, we propose Flattened Tree KV Splitting, a mechanism that ensures an even distribution of the KV cache across partitions with little computation redundancy, enhancing GPU utilization during attention computation. By reducing KV cache IO by 73-99% and IO for partial results by nearly 100% during attention calculation, DeFT achieves up to 2.23x and 3.59x speedups in end-to-end and attention latency, respectively, across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
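To show how an evenly split, flattened KV buffer can still give exact attention outputs, here is a minimal NumPy sketch under stated assumptions. Tree nodes are concatenated into one flat K/V buffer, a boolean mask records which tokens each query may attend to, the buffer is cut into fixed-size chunks regardless of node boundaries, and each chunk is processed once for all queries. Partial results are merged with the log-sum-exp rescaling used by split-KV decoding kernels such as Flash-Decoding; this is an assumed stand-in for DeFT's actual reduction, it runs on the CPU rather than as a GPU kernel, and every function name here is hypothetical.

```python
import numpy as np


def build_mask(node_sizes, paths):
    """Flatten tree nodes onto one token axis; mask[q, t] is True if query q may attend to token t."""
    offsets = np.concatenate([[0], np.cumsum(node_sizes)])
    mask = np.zeros((len(paths), offsets[-1]), dtype=bool)
    for qi, path in enumerate(paths):
        for node in path:  # path = node ids from the root to this query's leaf
            mask[qi, offsets[node]:offsets[node + 1]] = True
    return mask


def chunk_partial(q, k, v, mask):
    """Attention of all queries over one KV chunk (loaded once); returns unnormalized partials."""
    s = np.where(mask, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    m = s.max(axis=1)                              # -inf if a query sees nothing in this chunk
    safe_m = np.where(np.isfinite(m), m, 0.0)
    p = np.exp(s - safe_m[:, None])                # masked scores become exp(-inf) = 0
    return m, p.sum(axis=1), p @ v                 # chunk max, denominator, numerator


def merge_partials(partials):
    """Combine per-chunk partials with log-sum-exp rescaling (Flash-Decoding-style reduction)."""
    ms = np.stack([m for m, _, _ in partials])     # shape (chunks, queries)
    global_m = ms.max(axis=0)                      # finite if every query sees >= 1 token
    scale = np.exp(ms - global_m)                  # 0 for chunks a query cannot see
    denom = sum(sc * l for sc, (_, l, _) in zip(scale, partials))
    numer = sum(sc[:, None] * o for sc, (_, _, o) in zip(scale, partials))
    return numer / denom[:, None]


def split_attention(q, k, v, mask, chunk):
    """Cut the flat KV buffer into equal chunks, ignoring node boundaries, then merge."""
    parts = [chunk_partial(q, k[i:i + chunk], v[i:i + chunk], mask[:, i:i + chunk])
             for i in range(0, k.shape[0], chunk)]
    return merge_partials(parts)


def reference_attention(q, k, v, mask):
    """Plain masked softmax attention over the whole flat buffer, for checking."""
    s = np.where(mask, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p @ v) / p.sum(axis=1, keepdims=True)


# A 6-token shared prefix (node 0) with three branches of 3, 2 and 4 tokens.
sizes, paths = [6, 3, 2, 4], [[0, 1], [0, 2], [0, 3]]
mask = build_mask(sizes, paths)
rng = np.random.default_rng(0)
q = rng.standard_normal((len(paths), 8))
k, v = rng.standard_normal((2, sum(sizes), 8))
print(np.allclose(split_attention(q, k, v, mask, chunk=4), reference_attention(q, k, v, mask)))  # → True
```

Each chunk has the same number of tokens even though the nodes do not, which is what keeps partitions balanced. Masked positions inside a chunk are computed and then discarded, a small amount of redundant work of the kind the abstract refers to as "little computation redundancy".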
