LG · CL · Aug 7, 2024

Tree Attention: Topology-aware Decoding for Long-Context Attention on GPU clusters

arXiv:2408.04093v4 · 15 citations · h-index: 26 · Has Code
AI Analysis

This work addresses slow cross-device decoding of long-context large language models on GPU clusters, offering speedups and memory savings for researchers and practitioners who run attention across multiple GPUs.

The paper tackles the problem of parallelizing exact attention computation for long-context decoding across multiple GPUs, achieving up to 8x faster decoding than state-of-the-art methods like Ring Attention, with 2x less peak memory and reduced communication volume.

Our formulation reveals that the reduction across the sequence axis can be efficiently computed in parallel through a tree reduction. Our algorithm, called Tree Attention, for parallelizing exact attention computation across multiple GPUs enables cross-device decoding to be performed asymptotically faster (up to 8x faster in our experiments) than state-of-the-art approaches such as Ring Attention, while also requiring significantly less communication volume and incurring 2x less peak memory. We demonstrate that Tree Attention speeds up decoding up to 4x on Llama 3.1-8B and can be applied to a variety of hardware and networking setups such as H100 DGX nodes, AMD MI300x nodes, and PCIe connected NVIDIA RTX 4090s. Our code is publicly available here: https://github.com/Zyphra/tree_attention
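The tree reduction mentioned in the abstract can be illustrated with a small single-process sketch. It assumes one decoding query and a key/value sequence split into shards that stand in for GPUs. Each shard produces a partial result (running maximum, softmax denominator, unnormalized output), and because the combine step is associative, partials can be merged pairwise in about log2(num_shards) rounds. All function names here are hypothetical and not taken from the Zyphra repository; the actual implementation runs on GPUs with cross-device communication, which this sketch does not model.

```python
import math


def local_partial(q, keys, values):
    """Partial attention for one query over one shard.

    Returns (max score, sum of shifted exps, unnormalized weighted values).
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    numer = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return m, denom, numer


def combine(a, b):
    """Associative merge of two partials; rescales both to a shared maximum."""
    m_a, d_a, n_a = a
    m_b, d_b, n_b = b
    m = max(m_a, m_b)
    scale_a, scale_b = math.exp(m_a - m), math.exp(m_b - m)
    denom = d_a * scale_a + d_b * scale_b
    numer = [x * scale_a + y * scale_b for x, y in zip(n_a, n_b)]
    return m, denom, numer


def tree_reduce(parts):
    """Merge partials pairwise, halving the list each round."""
    parts = list(parts)
    while len(parts) > 1:
        merged = [combine(parts[i], parts[i + 1])
                  for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:  # an odd leftover moves up to the next round
            merged.append(parts[-1])
        parts = merged
    return parts[0]


def tree_attention_decode(q, keys, values, num_shards):
    """Exact attention output for one query, computed shard by shard."""
    size = math.ceil(len(keys) / num_shards)
    parts = [local_partial(q, keys[i:i + size], values[i:i + size])
             for i in range(0, len(keys), size)]
    _, denom, numer = tree_reduce(parts)
    return [x / denom for x in numer]


def reference_attention(q, keys, values):
    """Unsharded softmax attention, used only to check the sketch."""
    _, denom, numer = local_partial(q, keys, values)
    return [x / denom for x in numer]
```

Calling `tree_attention_decode(q, keys, values, num_shards=8)` should agree with `reference_attention(q, keys, values)` up to floating-point error, which is the sense in which the sharded computation remains exact rather than approximate.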

Code Implementations: 1 repo
