SystolicAttention: Fusing FlashAttention within a Single Systolic Array
This work addresses the low systolic-array utilization that FlashAttention incurs on transformer accelerators, a hardware-level bottleneck relevant to AI chip design.
The paper tackles the challenge of efficiently executing FlashAttention on systolic-array accelerators. It proposes FSA, an enhanced systolic array architecture that runs FlashAttention entirely within the array and achieves 1.77x and 4.83x higher attention FLOPs/s utilization than AWS Neuron v2 and Google TPUv5e, respectively, with 12% area overhead.
Transformer models rely heavily on scaled dot-product attention (SDPA), typically implemented using the FlashAttention algorithm. However, current systolic-array-based accelerators face significant challenges when executing FlashAttention. Systolic arrays achieve high utilization primarily for consecutive and large matrix multiplications, whereas FlashAttention requires frequent interleaving of matrix multiplications and softmax operations. The frequent data swaps between matrix multiplications on the systolic array and softmax operations on external units result in low array utilization. Moreover, when these computations run concurrently, the softmax stage contends with matrix multiplication for register file and SRAM ports, further degrading performance. To overcome these limitations, we propose FSA, an enhanced systolic array architecture that enables the FlashAttention algorithm to run entirely within a single systolic array, eliminating the need for external vector units. At the core of FSA is SystolicAttention, a novel scheduling algorithm that maps FlashAttention operations onto systolic arrays with fine-grained, element-wise overlap. This approach significantly improves array utilization while preserving the original floating-point operation order to maintain numerical stability. We implement FSA in synthesizable RTL and evaluate its performance against state-of-the-art commercial accelerators. Our results show that FSA achieves 1.77 and 4.83 times higher attention FLOPs/s utilization compared to AWS Neuron v2 and Google TPUv5e, respectively, with only 12% area overhead.
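The interleaving of matrix multiplications and softmax operations described above can be illustrated with a minimal pure-Python sketch of a FlashAttention-style forward pass (tiled attention with an online softmax). This is not the FSA hardware schedule or the SystolicAttention algorithm, which the abstract does not spell out. The function names, the `block` parameter, and the row-by-row loop structure are assumptions made for illustration only.

```python
import math


def naive_attention(Q, K, V):
    """Reference SDPA: softmax(Q K^T / sqrt(d)) V, one query row at a time."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        out.append([sum(e * v[j] for e, v in zip(exps, V)) / denom
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """FlashAttention-style forward pass over key/value tiles of size `block`.

    Each tile alternates a matmul (q . K_tile^T), a softmax update
    (running max and denominator), and a second matmul (P_tile . V_tile).
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for q in Q:
        m = -math.inf      # running row maximum
        l = 0.0            # running softmax denominator
        acc = [0.0] * dv   # running unnormalized output row
        for start in range(0, len(K), block):
            Kb = K[start:start + block]
            Vb = V[start:start + block]
            # Matmul 1: scaled scores of this query against the key tile.
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
            # Softmax step: update the max and rescale earlier partial sums.
            m_new = max(m, max(s))
            corr = math.exp(m - m_new)  # 0.0 on the first tile (m = -inf)
            p = [math.exp(x - m_new) for x in s]
            l = l * corr + sum(p)
            # Matmul 2: accumulate the tile's probabilities times the values.
            acc = [a * corr + sum(pi * v[j] for pi, v in zip(p, Vb))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out
```

In this sketch every key/value tile forces a matmul, then a softmax update, then another matmul. On an accelerator with a separate vector unit, each of those transitions would correspond to the data movement between the array and the external unit that the abstract identifies as the source of low utilization.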