FlatAttention: Dataflow and Fabric Collectives Co-Optimization for Large Attention-Based Model Inference on Tile-Based Accelerators
This work addresses a critical bottleneck for AI hardware developers, namely attention computation in large-model inference. It offers significant performance improvements but is incremental, as it builds on existing tile-based accelerator architectures.
The paper tackles inefficient attention computation during inference of large mixture-of-experts models on tile-based accelerators. It proposes FlatAttention, a dataflow co-optimized with on-chip fabric collectives to relieve memory bottlenecks, achieving up to 92.3% utilization, a 4.1x speedup over FlashAttention-3, and 16x lower HBM traffic.
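To see why fabric collectives can cut HBM traffic, the back-of-envelope model below compares two ways of feeding one attention head to a g x g tile group. The mapping is a hypothetical assumption, not the paper's dataflow: tile (i, j) handles query block Q_i against key/value block K_j, V_j. Without collectives, every tile fetches its own operands and spills its partial output to HBM. With idealized multicast (Q_i along row i, K_j and V_j along column j) and an on-fabric reduction of partial outputs, each byte crosses HBM once. The names `AttnShape`, `hbm_bytes_no_collectives` and `hbm_bytes_with_collectives` are illustrative, 16-bit elements are assumed, and softmax statistics and partial-output read-back are ignored. The resulting ratio is simply g, so it should not be compared directly with the paper's measured 16x.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnShape:
    seq_q: int               # number of query rows for one head
    seq_kv: int              # number of key/value rows for one head
    head_dim: int            # head dimension d
    bytes_per_elem: int = 2  # assumption: 16-bit activations


def tensor_bytes(rows: int, shape: AttnShape) -> int:
    """Bytes occupied by a (rows x head_dim) matrix."""
    return rows * shape.head_dim * shape.bytes_per_elem


def hbm_bytes_no_collectives(shape: AttnShape, g: int) -> int:
    """Hypothetical g x g tile group without fabric collectives.

    Tile (i, j) fetches Q_i, K_j and V_j from HBM by itself and writes its
    partial output O_ij back to HBM for a later combine (read-back ignored).
    """
    if shape.seq_q % g or shape.seq_kv % g:
        raise ValueError("sequence lengths must be divisible by g")
    q_blk = tensor_bytes(shape.seq_q // g, shape)
    kv_blk = tensor_bytes(shape.seq_kv // g, shape)
    per_tile = q_blk + 2 * kv_blk + q_blk  # read Q_i, K_j, V_j; write partial O_ij
    return g * g * per_tile


def hbm_bytes_with_collectives(shape: AttnShape) -> int:
    """Same mapping with idealized fabric collectives.

    Q_i is read once and multicast along tile row i, K_j and V_j are read once
    and multicast along tile column j, and partial outputs are reduced on the
    fabric so each output row is written once: every byte crosses HBM once.
    """
    q = tensor_bytes(shape.seq_q, shape)
    kv = tensor_bytes(shape.seq_kv, shape)
    return q + 2 * kv + q  # read Q, K, V; write O


if __name__ == "__main__":
    s = AttnShape(seq_q=4096, seq_kv=4096, head_dim=128)
    ratio = hbm_bytes_no_collectives(s, g=4) / hbm_bytes_with_collectives(s)
    print(ratio)  # 4.0
```

In this toy model the saving grows linearly with the group side g because each operand block is shared by g tiles. Real dataflows also trade off on-chip buffer capacity, fabric bandwidth and synchronization, which this sketch does not capture.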
Attention accounts for an increasingly dominant fraction of total computation during inference for mixture-of-experts (MoE) models, making efficient acceleration critical. Emerging domain-specific accelerators for large model inference are shifting toward chip-scale and wafer-scale tile-based architectures. Tiles contain large matrix and vector engines and are connected through on-chip interconnects, which support tile-to-tile traffic to reduce the tile-to-main-memory traffic bottleneck. Hence, dataflow management is crucial to achieve high utilization. We propose FlatAttention, a dataflow for modern attention variants on tile-based accelerators. FlatAttention minimizes expensive high-bandwidth memory (HBM) accesses by exploiting collective primitives integrated into the on-chip network fabric, achieving up to 92.3% utilization, 4.1x speedup over FlashAttention-3, and 16x lower HBM traffic. On a 32x32 tile configuration with peak performance comparable to NVIDIA GH200, FlatAttention generalizes across multiple attention variants, achieving an average of 86% utilization for compute-bound attentions and 78% HBM bandwidth utilization for memory-bound ones, resulting in an average 1.9x speedup over attention implementations on GH200. Finally, we evaluate end-to-end DeepSeek-v3 FP8 decoding with FlatAttention on a wafer-scale multi-die system, achieving a 1.9x improvement in system throughput and a 1.4x reduction in per-user token output latency, despite operating with 1.5x lower peak system performance compared to the state-of-the-art solution.
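When the key/value sequence is spread over several tiles, each tile can only compute a partial attention result, and the partials must be merged. The sketch below shows the standard split-KV merge used by FlashAttention-style kernels (online softmax): each partial carries an unnormalized output plus a per-row running max `m` and exponential sum `l`, and two partials combine with a rescaling step. Because this merge is associative, it can in principle be executed as a tree reduction, which is the kind of operation an in-fabric reduce collective could perform. This is an illustrative assumption about how partials could be combined, not FlatAttention's actual dataflow; the function names and the `n_tiles` partition are hypothetical.

```python
from functools import reduce

import numpy as np


def partial_attention(q, k, v):
    """Attention of q over one KV block, as one tile might compute it.

    Returns the unnormalized output and the per-row online-softmax state.
    """
    s = q @ k.T / np.sqrt(q.shape[-1])   # scaled scores, shape (n_q, n_kv_block)
    m = s.max(axis=-1, keepdims=True)    # row-wise max for numerical stability
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)    # row-wise sum of exponentials
    return p @ v, m, l


def combine(a, b):
    """Merge two partial states; associative, so usable in a tree reduction."""
    o_a, m_a, l_a = a
    o_b, m_b, l_b = b
    m = np.maximum(m_a, m_b)
    alpha, beta = np.exp(m_a - m), np.exp(m_b - m)  # rescale to the shared max
    return o_a * alpha + o_b * beta, m, l_a * alpha + l_b * beta


def split_kv_attention(q, k, v, n_tiles):
    """Split K/V rows across hypothetical tiles, then reduce the partials."""
    parts = [
        partial_attention(q, kb, vb)
        for kb, vb in zip(np.array_split(k, n_tiles), np.array_split(v, n_tiles))
    ]
    o, _, l = reduce(combine, parts)  # a sequential fold stands in for a fabric reduce
    return o / l


def reference_attention(q, k, v):
    """Plain softmax attention for comparison."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v
```

For example, with `q` of shape (8, 64) and `k`, `v` of shape (256, 64), `split_kv_attention(q, k, v, n_tiles=4)` matches `reference_attention(q, k, v)` up to floating-point rounding. Only the small per-row state and the output block need to move between tiles, rather than the full score matrix.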