Sequence-Aware Split Heuristic to Mitigate SM Underutilization in FlashAttention-3 Low-Head-Count Decoding
This addresses an incremental optimization problem for users of FlashAttention-3 in low-head-count decoding scenarios on Hopper GPUs.
The paper tackled the GPU occupancy bottleneck in FlashAttention-3 during low-head-count decoding by proposing a sequence-aware split policy, resulting in a 21-24% improvement in decoder kernel efficiency.
The standard FlashAttention-3 heuristic exhibits a GPU occupancy bottleneck in low-head-count decoding configurations because it disables sequence splitting based on sequence length alone, underutilizing the Streaming Multiprocessors of Hopper GPUs. Our proposed sequence-aware split policy mitigates this by allowing sequence-level parallelism in low-head-count regimes, improving hardware utilization to deliver roughly a 21 to 24% improvement in decoder kernel efficiency on metadata-enabled inference paths, with no observed regressions.