<aside> 💡
If you find any mistakes in this article, please feel free to contact me through comment or [email protected].
Published: October 18, 2025
Last modified: October 18, 2025
</aside>
FlashAttention adopts the Flash-Decoding technique during the decode stage. To fully exploit GPU parallelism when computing attention, the KV cache for a single request is split into multiple sub-sequences. Each sub-sequence is processed on a different GPU SM with the same query, and a final combine kernel merges the attention results across sub-sequences. Within the same decode batch, KV cache lengths vary by request, so the split strategy must be computed dynamically per request. This post first analyzes the dynamic split strategy in FlashAttention-3 on Hopper and shows, through simulations and experiments on real GPUs, that the current strategy is prone to wave quantization, which can degrade decode performance, especially under certain request-length distributions. To address this, we propose an improved split strategy. Through simulations and real-GPU benchmarks, we show that the new strategy often outperforms the current FlashAttention approach, delivering up to 30% speedups for common request distributions with essentially no regressions.
Concretely, this post aims to:
- explain how FlashAttention-3 determines its split strategy dynamically during decode;
- show, through simulations and experiments on real GPUs, that the current strategy is prone to wave quantization and can degrade decode performance under certain request-length distributions;
- propose an improved split strategy and evaluate it against the current FlashAttention approach.
Unless otherwise specified, “attention” refers to GQA with 8 query heads, 1 KV head, and head_dim = 128, corresponding to a single-GPU configuration for Llama3-70B in a TP8 deployment. Our analysis and performance tests primarily target H20, which has 78 SMs (Streaming Multiprocessors). We also analyze H800 in some cases (132 SMs). Since we focus on Hopper, all analysis is based on FlashAttention-3 (FA3).
Assume the decode batch size is $B$, and the KV cache length of request $i$ is $T_i$, $1 \le i \le B$. In FA3, for hardware efficiency, both the query and KV cache are loaded in block-sized units. On the KV dimension, one block corresponds to 176 tokens (defined in tile_size.h). Each request thus has $L_i=\lceil \frac{T_i}{176} \rceil$ blocks.
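For reference, here is a minimal Python sketch of this block-count calculation (the constant and helper names are ours):

```python
import math

KV_BLOCK_TOKENS = 176  # KV-dimension tile size for this FA3 configuration

def num_kv_blocks(kv_len: int) -> int:
    """L_i = ceil(T_i / 176): number of KV blocks for a request of KV length T_i."""
    return math.ceil(kv_len / KV_BLOCK_TOKENS)

# A 4096-token KV cache occupies ceil(4096 / 176) = 24 blocks.
assert num_kv_blocks(4096) == 24
```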
In many cases, $B$ is smaller than the number of GPU SMs. Assigning one SM per request would waste SM resources. Flash-Decoding therefore splits each request’s KV cache along the sequence dimension into multiple sub-sequences, runs each sub-sequence on a separate SM, and tries to saturate GPU parallel compute resources.
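To make the combine step concrete, here is a small NumPy sketch of the log-sum-exp merge used by Flash-Decoding-style kernels. It is an illustration only: the function names are ours, and the real combine kernel works per head on the partial outputs and LSE values written by each split.

```python
import numpy as np

def attn_reference(q, K, V):
    """Single-query attention over the full KV cache (no splitting)."""
    s = (K @ q) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

def attn_split_combine(q, K, V, num_splits):
    """Flash-Decoding style: attention per KV chunk, then a log-sum-exp merge."""
    d = q.shape[-1]
    partial_o, partial_lse = [], []
    for Kc, Vc in zip(np.array_split(K, num_splits), np.array_split(V, num_splits)):
        s = (Kc @ q) / np.sqrt(d)
        m = s.max()
        p = np.exp(s - m)
        partial_o.append((p @ Vc) / p.sum())      # normalized partial output of this chunk
        partial_lse.append(m + np.log(p.sum()))   # log-sum-exp of this chunk's scores
    lse = np.logaddexp.reduce(partial_lse)        # global log-sum-exp across chunks
    weights = np.exp(np.array(partial_lse) - lse) # per-chunk weights, sum to 1
    return sum(w * o for w, o in zip(weights, partial_o))

rng = np.random.default_rng(0)
T, d = 4096, 128
q, K, V = rng.standard_normal(d), rng.standard_normal((T, d)), rng.standard_normal((T, d))
assert np.allclose(attn_reference(q, K, V), attn_split_combine(q, K, V, num_splits=3))
```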
This raises the question: how should we split a sequence, i.e., how many sub-sequences should each request’s KV cache be divided into? Consider a simple example where all requests in the batch have identical sequence lengths. In one decode batch, there are 32 requests, each with KV cache length 4096 (24 blocks). Two splitting strategies are:
- Split each request into 2 sub-sequences of 12 blocks each. This yields 64 work units for 78 SMs, so every sub-sequence runs in a single wave and the busiest SM processes 12 blocks.
- Split each request into 3 sub-sequences of 8 blocks each. This yields 96 work units, more than the 78 SMs, so 18 SMs must run a second wave with 2 sub-sequences and process 16 blocks.
From the above, the maximum number of blocks per SM in the first strategy (12) is 75% of that in the second (16). Using the latest FlashAttention code (the 10/15 commit) for this test, the first strategy takes 93 µs while the second takes 118 µs. We can conclude that splitting more finely is not always better: kernel latency is governed by the maximum number of blocks any single SM must process, and over-splitting causes wave quantization, where the number of work units slightly exceeds the number of SMs and forces a poorly utilized extra wave.
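The effect is easy to reproduce with a pen-and-paper model. The sketch below (our own helper, assuming sub-sequences are dealt out to the H20's 78 SMs in a simple round-robin order) computes the busiest SM's block count for both strategies:

```python
def max_blocks_per_sm(blocks_per_request, num_splits_per_request, num_sms):
    """Simulate assigning sub-sequences to SMs and return the busiest SM's block count.

    Assumes each request's blocks are divided as evenly as possible among its splits
    and that work units are dealt out to SMs round-robin, largest first.
    """
    units = []
    for blocks, splits in zip(blocks_per_request, num_splits_per_request):
        base, rem = divmod(blocks, splits)
        units += [base + 1] * rem + [base] * (splits - rem)
    sm_load = [0] * num_sms
    for i, u in enumerate(sorted(units, reverse=True)):
        sm_load[i % num_sms] += u
    return max(sm_load)

batch = [24] * 32                                       # 32 requests x 24 KV blocks (4096 tokens)
print(max_blocks_per_sm(batch, [2] * 32, num_sms=78))   # 12 -> 64 units, a single wave
print(max_blocks_per_sm(batch, [3] * 32, num_sms=78))   # 16 -> 96 units, second wave on 18 SMs
```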
Therefore, the goal of sequence splitting in Flash-Decoding is to determine, for each request, the number of sub-sequences that saturates the SMs while avoiding wave-quantization-induced slowdowns. Splitting can still leave the SMs load-imbalanced, with some SMs processing more blocks than others, so the optimization target is to minimize the maximum number of blocks processed by any single SM.
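To make the objective concrete, the sketch below reuses max_blocks_per_sm from the previous snippet and brute-forces a single split count shared by all requests. It illustrates the optimization target only and is not FA3's actual heuristic, which must also cope with varying per-request lengths.

```python
def best_uniform_split(blocks_per_request, num_sms, max_splits=16):
    """Brute-force a single split count shared by all requests.

    Returns (num_splits, max_blocks_per_sm) minimizing the busiest SM's load.
    """
    best = None
    for s in range(1, max_splits + 1):
        cost = max_blocks_per_sm(blocks_per_request, [s] * len(blocks_per_request), num_sms)
        if best is None or cost < best[1]:
            best = (s, cost)
    return best

print(best_uniform_split([24] * 32, num_sms=78))  # -> (2, 12) for the example above
```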