<aside> 💡

如果在本文中发现一些问题或者逻辑错误,欢迎随时评论或者通过[email protected]联系到我。

发布于: 2025年10月18日

最后修改: 2025年10月18日

</aside>

引言

FlashAttention 在decode阶段使用了 Flash-Decoding 技术,即为了充分利用GPU的并行度,在计算attention的时候,会将同一个请求的KVCache切分成多个子序列,每个子序列在不同的GPU SM上和相同的query进行计算,最后通过一个combine kernel将不同子序列的attention计算结果合并。在同一个decode batch中,不同请求的KVCache长度不同,因此需要动态地根据每个请求的长度计算切分策略。在这篇文章中,首先会分析当前FlashAttention-3在Hopper架构上的动态切分(dynamic split)策略,并且通过程序模拟分析和实测数据说明当前的策略很容易出现wave quantization导致的decode性能下降问题;包括为什么在特定请求分布的情况下更容易出现性能下降问题。为了解决这个问题,这篇文章提出了一个改进的切分策略,通过模拟实验和实际算子性能测试说明新的策略在很多情况下性能都好于FlashAttention当前的策略,在一些常见请求分布下能够实现30%的性能加速,并且基本上没有出现性能下降的bad case。

具体来说,这篇文章主要目的是说明如下的分析和结论:

  1. 介绍当前FlashAttention-3对于decode阶段序列维度进行切分的算法,以及分析算法设计的背后逻辑;
  2. 基于模拟和实测数据,说明当前切分策略很容易触发由于wave quantization问题导致的性能问题,并且这个性能问题在请求batch相对较大(比如H20上batch>25),batch里请求长度比较平均的情况下非常容易出现。
  3. 在FlashAttention-3现有框架内,给出优化这个问题的一个可能方法,即贪心的对部分子序列重新切分降低SM使用,这个方法可以在GPU上有效的实现,性能开销大约2 µs左右,但是可以在很多输入分布下带来比较明显的性能提升,在实际GPU测试的一些输入上可以提升33%(比如680 µs 优化到 453 µs)。

术语说明与问题定义

首先,如果不特别说明,下文中提到的attention结构都是指GQA,包括8个query head,1个KV head,head_dim是128,这个对应Llama3-70B模型在TP8部署下的单卡配置。此外,文章里的理论分析和性能测试主要是基于H20进行分析,H20的SM(Streaming Multiprocessor)个数是78;在特定情况会对H800的性能做分析,H800的SM个数是132。由于是针对Hopper架构,所以这里的所有分析都是基于 FlashAttention-3(下文用FA3来简称)来进行。 假设一次decode里请求的batch个数为 $B$,每个请求的KVCache长度是 $T_i, i ≤ B$。在FA3中,为了GPU硬件效率,无论是query还是KVCache,都是按照一个block为最小单位的方式进行加载;实际代码在KV维度按照176 token为一个block(定义在tile_size.h里面),那么每个请求对应 $L_i=\lceil \frac{T_i}{176} \rceil$个block。 由于在很多情况下,batch $B$ 都是小于 GPU 上SM的个数的,因此使用一个SM处理一个请求,会导致SM资源浪费,所以在 Flash Decoding 中,会将请求的KVCache按照序列维度进行切分,拆成多个子序列,每个子序列在一个单独的SM上执行,尽量用满GPU的并行计算资源。 这样就会面临一个问题,如何对序列长度进行切分,每个请求的KVCache要拆成几个子序列最合适?举一个简单的例子,假设batch内请求长度全都相同的情况,即在同一个decode batch中,一共32个请求,每个请求的KVCache序列长度都是 4096(对应 24 个block)。那么有两种切分策略:

  1. 第一种方案是每个序列拆成2个子序列,每个子序列对应2112个token(12个block),一共64个子序列,对应64个SM,每个SM处理12个block;
  2. 第二种方案是每个序列拆成3个子序列,每个子序列对应1408个token(8个block),一共96个子序列,对应96个SM(2个GPU wave)。由于 GPU 仅有78个SM,因此有 (96-78=18) 个子序列落在第二个wave,第二个wave的利用率偏低,第二个wave的18个SM均分别需要处理8 × 2 = 16个block。这个也是FA3的默认切分策略,后面会分析FA3设计这种算法的原因。

从上面的例子,我们可以看到第一种方案中SM需要处理的最大block个数是第二种的75%,在实际最新的FlashAttention代码(对应10.15日commit )跑的测试中,第一个方案需要93 µs,第二个方案需要118 µs,因此至少可以得到如下的结论:

  1. 不同的切分方式会真实且明显的影响实际性能;
  2. 在FA3的默认实现中,触发了wave quantization问题,即SM工作划分方式导致了任务SM使用量略高于GPU的SM总数,第二个wave利用率低。

所以Flash Decoding的序列切分问题,从计算角度目标就是根据每个请求的实际长度,计算出来切分出的子序列个数,尽量占用满GPU的SM资源,同时要避免出现wave quantization问题导致的性能下降。从数学上,切分策略通常会导致SM负载不均衡问题,即一个SM(或者几个)处理的block数量高于其他的SM,优化目标即找到一个最好的策略来最小化负载最大的SM需要处理的工作。

分析

FA3当前的计算逻辑

接下来介绍FA3的动态切分策略,总体逻辑如下(对应FA3代码中的prepare_varlen_num_blocks_kernel函数):

  1. 首先计算所有请求的block之和:total_blocks,$total\{blocks} = \sum{i}^{B}L_i$。
  2. 根据GPU上的SM个数,计算平均每个SM上需要划分到多少block,$blocks\per\sm=\left\lceil\frac{total\{blocks}*1.1}{num\{sm}}\right\rceil$,这里面1.1是一个margin,具体作用我们稍后分析。