<aside> 💡

If you find any mistakes in this article, please feel free to contact me in the comments or by email at [email protected].

Published: October 18, 2025

Last modified: October 18, 2025

</aside>

Introduction

FlashAttention uses the Flash-Decoding technique in the decode stage. To fully exploit GPU parallelism, the KV cache of a single request is split into multiple sub-sequences. Each sub-sequence is processed on a different SM against the same query, and a final combine kernel merges the partial attention results across sub-sequences. Because KV cache lengths vary across requests in the same decode batch, the split strategy must be computed dynamically for each request.

This post first analyzes the dynamic split strategy in FlashAttention-3 on Hopper. Through simulations and experiments on real GPUs, it shows that the current strategy is prone to wave quantization, which can degrade decode performance, especially under certain request-length distributions. To address this, we propose an improved split strategy. Simulations and real-GPU benchmarks show that the new strategy often outperforms the current FlashAttention approach and can deliver up to 30% speedups for common request distributions, with essentially no performance regressions.
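The combine step is what makes splitting safe. The pure-Python sketch below models it for a single query head. It is an illustrative model, not FA3's kernel: `partial_attention`, `combine`, and `split_kv_attention` are hypothetical names, and the standard $1/\sqrt{d}$ score scaling is assumed. Each sub-sequence returns its locally normalized output plus the log-sum-exp (LSE) of its scores. The combine step then reweights each partial output by $e^{\mathrm{LSE}_j - \mathrm{LSE}}$, which recovers the full-sequence result exactly.

```python
import math

def partial_attention(q, keys, values):
    """Attend one query vector to one KV sub-sequence.

    Returns the locally normalized output and the log-sum-exp (LSE) of the
    scaled scores; the combine step needs the LSE to reweight partial results.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max before exp() for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)

def combine(partials):
    """Merge (output, lse) pairs from all sub-sequences into the full result."""
    m = max(lse for _, lse in partials)
    total_lse = m + math.log(sum(math.exp(lse - m) for _, lse in partials))
    dim = len(partials[0][0])
    return [sum(math.exp(lse - total_lse) * out[d] for out, lse in partials)
            for d in range(dim)]

def split_kv_attention(q, keys, values, num_splits):
    """Split the KV cache into contiguous chunks, attend each, then combine."""
    chunk = math.ceil(len(keys) / num_splits)
    partials = [partial_attention(q, keys[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(keys), chunk)]
    return combine(partials)

# Usage: any split count reproduces the unsplit result up to rounding error.
q = [0.1, -0.2, 0.3, 0.4]
keys = [[math.sin(i + d) for d in range(4)] for i in range(10)]
values = [[math.cos(i * d) for d in range(4)] for i in range(10)]
reference = split_kv_attention(q, keys, values, num_splits=1)
for n in (2, 3, 4):
    result = split_kv_attention(q, keys, values, num_splits=n)
    print(n, max(abs(a - b) for a, b in zip(reference, result)))
```

Because the result does not depend on the split count, the number of splits can be chosen purely for performance. That choice is the subject of the rest of this post.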

Concretely, this post aims to:

  1. Introduce the algorithm FlashAttention-3 uses to split along the sequence dimension during decoding, and explain the rationale behind its design.
  2. Using simulations and measurements, show that the current strategy can easily trigger performance issues due to wave quantization—especially when the decode batch is relatively large (e.g., on H20 when batch > 25) and per-request lengths are relatively uniform within the batch.
  3. Within the FA3 framework, present an optimization that greedily re-splits some sub-sequences to reduce SM usage. It can be implemented efficiently on the GPU with about 2 µs of overhead and yields noticeable speedups under many input distributions. In real GPU tests, we observed up to a 33% latency reduction (from 680 µs to 453 µs).

Terminology and Problem Definition

Unless otherwise specified, “attention” refers to GQA with 8 query heads, 1 KV head, and head_dim = 128, which corresponds to the per-GPU configuration of Llama3-70B in a TP8 deployment. Our analysis and performance tests primarily target H20, which has 78 SMs (Streaming Multiprocessors). In some cases we also analyze H800 (132 SMs). Since we focus on Hopper, all analysis is based on FlashAttention-3 (FA3).
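For reference, Llama3-70B has 64 query heads and 8 KV heads (head_dim = 128). Sharding the heads evenly across 8 tensor-parallel GPUs gives the per-GPU shape assumed above, with each KV head shared by a group of 8 query heads:

$$
\frac{64\ \text{query heads}}{8\ \text{GPUs}} = 8\ \text{query heads per GPU}, \qquad \frac{8\ \text{KV heads}}{8\ \text{GPUs}} = 1\ \text{KV head per GPU}
$$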

Assume the decode batch size is $B$ and the KV cache length of request $i$ is $T_i$ ($1 \le i \le B$). For hardware efficiency, FA3 loads both the query and the KV cache in block-sized units. Along the KV dimension, one block corresponds to 176 tokens (defined in tile_size.h), so request $i$ has $L_i = \lceil \frac{T_i}{176} \rceil$ blocks.
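A minimal sketch of the block count, assuming the 176-token block size quoted above. The helper `num_kv_blocks` is hypothetical and not an FA3 function:

```python
KV_BLOCK_TOKENS = 176  # KV block size quoted in the text for this configuration

def num_kv_blocks(seq_len: int, block_tokens: int = KV_BLOCK_TOKENS) -> int:
    """L_i = ceil(T_i / block_tokens): KV blocks needed to cover one request."""
    return -(-seq_len // block_tokens)  # integer ceiling division, no float rounding

lengths = [4096, 1000, 176, 177]
print([num_kv_blocks(t) for t in lengths])  # [24, 6, 1, 2]
```

One extra token past a block boundary (177 vs. 176) costs a whole extra block, which is why the analysis below counts work in blocks rather than tokens.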

In many cases, $B$ is smaller than the number of SMs, so assigning one SM per request would leave SMs idle. Flash-Decoding therefore splits each request’s KV cache along the sequence dimension into multiple sub-sequences and runs each sub-sequence on a separate SM, aiming to saturate the GPU’s parallel compute resources.
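One way to picture the split: the sketch below divides a request's $L_i$ blocks into contiguous ranges. It assumes each sub-sequence gets $\lceil L_i / s \rceil$ blocks for $s$ splits and the last one takes the remainder. The helper names are hypothetical and not taken from FA3's source.

```python
def split_blocks(num_blocks: int, num_splits: int) -> list[tuple[int, int]]:
    """Divide blocks [0, num_blocks) into contiguous [start, end) ranges.

    Assumption: each range holds ceil(num_blocks / num_splits) blocks and the
    last range takes the remainder, so fewer than num_splits ranges may result.
    """
    if num_blocks == 0:
        return []
    per_split = -(-num_blocks // num_splits)
    return [(start, min(start + per_split, num_blocks))
            for start in range(0, num_blocks, per_split)]

def to_token_ranges(block_ranges, seq_len: int, block_tokens: int = 176):
    """Map block ranges to token ranges, clipping the final range at seq_len."""
    return [(b0 * block_tokens, min(b1 * block_tokens, seq_len))
            for b0, b1 in block_ranges]

print(to_token_ranges(split_blocks(24, 2), seq_len=4096))  # [(0, 2112), (2112, 4096)]
print(split_blocks(24, 3))                                  # [(0, 8), (8, 16), (16, 24)]
```

With this rounding, the last sub-sequence can be shorter than the others. A requested split count can also yield fewer non-empty ranges: 10 blocks with 6 splits gives 5 ranges of 2 blocks.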

This raises the question: how should we split a sequence, i.e., into how many sub-sequences should each request’s KV cache be divided? Consider a simple example in which all requests in the batch have the same length: a decode batch of 32 requests, each with a KV cache length of 4096 tokens (24 blocks). Two possible splitting strategies are:

  1. Split each sequence into 2 sub-sequences of at most 2112 tokens (12 blocks) each. This yields 64 sub-sequences, which map to 64 SMs, each processing 12 blocks.
  2. Split each sequence into 3 sub-sequences of at most 1408 tokens (8 blocks) each. This yields 96 sub-sequences, which need 2 waves on the GPU’s 78 SMs: $96 - 78 = 18$ sub-sequences spill into a second, poorly utilized wave, so those 18 SMs each process $8 \times 2 = 16$ blocks. This is also FA3’s default split for this input; we will analyze why FA3 is designed this way.
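The arithmetic behind the two strategies can be reproduced with a simplified wave model. The assumptions are ours, not FA3's: all requests have equal length, every sub-sequence has $\lceil L/s \rceil$ blocks, and sub-sequences are dispatched in waves of at most one per SM. `wave_cost` is a hypothetical helper.

```python
import math

def wave_cost(batch: int, blocks_per_req: int, num_splits: int, num_sms: int = 78):
    """Uniform-length batch: return (sub-sequences, waves, max blocks on any SM).

    Simplified model: each sub-sequence has ceil(L / s) blocks and the busiest
    SM runs one sub-sequence in every wave.
    """
    subseqs = batch * num_splits
    waves = math.ceil(subseqs / num_sms)
    blocks_per_subseq = math.ceil(blocks_per_req / num_splits)
    return subseqs, waves, waves * blocks_per_subseq

for s in (2, 3):
    print(s, wave_cost(batch=32, blocks_per_req=24, num_splits=s))
# 2 (64, 1, 12)
# 3 (96, 2, 16)
```

Under the same model, passing `num_sms=132` (H800) lets the 3-way split fit in a single wave (96 ≤ 132). This illustrates that wave quantization depends on the SM count as well as on the split.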

From the above, the maximum number of blocks per SM under the first strategy is 75% of that under the second (12 vs. 16). Benchmarked with the latest FlashAttention code (the October 15 commit), the first strategy takes 93 µs and the second takes 118 µs. We can conclude:

  1. Different split strategies have real and significant performance impact.
  2. FA3’s default implementation triggers wave quantization: the number of sub-sequences slightly exceeds the number of GPU SMs, creating a low-utilization second wave.

Therefore, the goal of sequence splitting in Flash-Decoding is to choose, for each request and its length, a number of sub-sequences that saturates the SMs without triggering wave-quantization slowdowns. Mathematically, splitting can cause load imbalance across SMs, with some SMs processing more blocks than others. The optimization target is to minimize the maximum number of blocks processed by any SM.
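Let $W_m$ be the number of blocks assigned to SM $m$. The target is then $\min \max_m W_m$ over the per-request split counts. The sketch below evaluates $\max_m W_m$ for one candidate choice of split counts in an arbitrary batch. It assumes sub-sequences are dispatched in order to the currently least-loaded SM, a greedy list-scheduling approximation of the hardware scheduler. It also ignores the cost of the combine kernel. `max_blocks_per_sm` is a hypothetical helper.

```python
import heapq
import math

def max_blocks_per_sm(blocks_per_req, splits_per_req, num_sms=78):
    """Largest number of KV blocks any SM processes for the given split counts.

    Modeling assumptions: sub-sequences are dispatched in order to the SM with
    the least work so far, and combine-kernel cost is ignored.
    """
    sm_loads = [0] * num_sms  # min-heap of per-SM block counts (all zero is a valid heap)
    for blocks, splits in zip(blocks_per_req, splits_per_req):
        if blocks == 0:
            continue
        per_split = math.ceil(blocks / splits)
        for start in range(0, blocks, per_split):
            work = min(per_split, blocks - start)  # the last split may be shorter
            # Pop the least-loaded SM and push it back with the new work added.
            heapq.heapreplace(sm_loads, sm_loads[0] + work)
    return max(sm_loads)

print(max_blocks_per_sm([24] * 32, [2] * 32))  # 12
print(max_blocks_per_sm([24] * 32, [3] * 32))  # 16
```

For the 32-request example, this reproduces the 12 vs. 16 blocks derived above. Because the model ignores combine cost and per-split overhead, very large split counts look artificially cheap under it. It is best used to compare a small set of candidate splits.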

Analysis

FA3’s Current Computation Logic