背景

矩阵乘法(GEMM)和注意力(Attention)可能是大模型计算中最重要的两个算子,这篇文章先谈矩阵乘法。在GPU上,矩阵GEMM的实际计算性能与理论性能之间通常存在差距。本文中我们使用MFU(Model Flops Utilization)来代表GEMM的实际FLOPS和GPU理论FLOPS的比值。根据过往的研究[1][2][3],矩阵计算的MFU受多个因素影响,包括GPU型号、GPU实际功率、矩阵shape、wave quantization以及cuBlas/CUTLASS等软件实现。这些众多变量导致了一个实际问题:缺乏可参考的基准指标来评估在特定GPU、软件版本和矩阵shape下的GEMM计算性能是否达到合理水平。在这篇文章里,我们列举一套规则来验证GEMM实际性能的合理性,将这些规则称为GEMM性能的一致性规则

基于这些一致性规则,我们通过程序自动识别出了一系列GEMM性能异常问题。以下是三个具有代表性的案例:

  1. 在RTX4090上,PyTorch 2.5.1+cuda12.4对于[batch, 4096] * [4096, 14336]的矩阵shape,相比PyTorch2.3.0+cuda12.1,在性能测试里出现了大范围的float16类型性能下降。这个shape对应Llama3-8B模型在TP=2下的up-proj矩阵计算,属于常见矩阵shape。
  2. 在A800 GPU上,PyTorch对[1024, 4096] * [4096, 4096]这个标准shape的GEMM计算效率异常低下,MFU仅为0.55。[4096, 4096]这个shape对应Llama3-8B模型的o-proj矩阵,而且1024是chunked-prefill常用的batch值,因此这个问题是实际中可能遇到的。
  3. 在Marlin kernel(vLLM推荐的int8 weight-only量化kernel)下,A100等Ampere架构GPU在batch>1024时的性能表现不佳。与相同shape的float16计算相比,GEMM计算时间增加了30%以上,部分情况下甚至增加50%。这与论文[4]中声称的性能差距小于10%不完全一致。最终我们发现了int4和int8量化的性能差异、vLLM Marlin版本迭代中的性能退化问题导致了这些不一致。

这些性能问题的原因和解决并不是这篇文章关注的重点,这里的主要目的是说明,通过一些规则,是可以发现一些在实际中真实出现且具有负向影响的性能不一致问题,本文主要是讨论一些自动发现这些问题的GEMM自动测试的方法。

文中所有的实验结果和可复现脚本在github上,实验都是在国内一家面向个人提供GPU租赁的云服务上完成。

一致性规则

MFU对于矩阵shape的一致性

在文章[1][2][3]中,作者指出由于tiling策略和wave quantization的影响,不同shape矩阵的计算效率会有所不同。然而在实际应用中,我们通常认为,当矩阵A的shape大于矩阵B时,矩阵A的MFU应该与矩阵B基本持平或略高,这是因为更大的shape能更充分地利用GPU的并行计算能力。如果矩阵A的MFU显著低于矩阵B,我们可以通过将A拆分为两个较小的矩阵分别计算来降低延迟。我们称为矩阵计算MFU对于矩阵shape的一致性原则

为了验证这一原则,我们在RTX4090上进行了实验,计算[batch, 4096] * [4096, 14336]的float16矩阵计算。这里的[4096, 14336]对应Llama3-8B模型的up-proj和gate_proj矩阵shape。我们选择了8192个矩阵计算,其中batch的值从[4, 8, 12, …, 32768]不等。对每个batch,我们多次执行torch.matmul并统计GPU上kernel的平均执行时间,据此计算实际FLOPS,再与RTX4090的理论性能对比,得到8192个GEMM MFU值(完整数据见profiler)。对于每个batch N,如果[N, 4096]矩阵的MFU低于任何较小矩阵[N1, 4096](N1 < N)计算效率的特定比例(如0.85),我们就将其标记为性能异常batch。通过脚本分析,我们获得了以下结果(为避免硬件差异带来的影响,选取了两个不同的机器分别实验)。从结果中,可以看出,8192个batch中有35%的的batch是有性能异常的。

GPU 异常GEMM检测阈值 异常batch N的个数 总batch个数 占总矩阵个数的比例
RTX4090实例1 0.9 3021 8192 37%
RTX4090实例1 0.85 2882 8192 35%
RTX4090实例2 0.9 3135 8192 38.2%
RTX4090实例2 0.85 2882 8192 35%

为了防止特定机器的GPU异常导致测量误差,我们对于执行结果执行两个额外检查,一个是选取了两个不同来源的4090机器,其次手动随机采样多个的矩阵计算来手动验证几个shape下的矩阵计算性能。我们发现虽然不同的4090 GPU的性能有一定波动性,但是整体趋势是类似的。

此外,基于这个发现,对于异常GEMM计算性能的batch,比如N,我们会搜索所有N1+N2=N的两个小矩阵[N1, 4096]和[N2, 4096],然后找到对应两个矩阵的GEMM性能,将二者求和,如果二者之和明显小于batch=N对应的矩阵计算延迟,我们称矩阵[N, 4096]为splitwise_gemm可优化矩阵,其中splitwise_gemm的定义如下,即通过两次矩阵乘法来计算原矩阵的乘法。

def splitwise_gemm(size_m, size_k, size_n, size_m_split):
    A = torch.randn(size_m, size_k, dtype=torch.float16)
    B = torch.randn(size_k, size_n, dtype=torch.float16)
    C = torch.empty(size_m, size_n, dtype=torch.float16)
    torch.matmul(A[0:size_m_split, :], B, out=C1[0:size_m_split, :])
    torch.matmul(A[size_m_split:size_m, :], B, out=C1[size_m_split:size_m, :])
    return C

同时对手动选取的几个矩阵来测试splitwise_gemm计算,我们确实实现了性能提升,详细可以参见如下的表格。

batch GEMM计算效率 计算时间 (us) splitwise_gemm计算时间(us) splitwise_gemm对应两个子矩阵
2180 0.7323 2106.2 1716.7 [192, 4096], [1988, 4096]
3076 0.7427 2930.2 2374.3 [384, 4096], [2692, 4096]
4228 0.7721 3874.0 3170.6 [172, 4096], [4056, 4096]

**注:**RTX4090的上述问题只在float16类型中存在,在bfloat16类型不存在,具体原因还不清楚。

另一个令我比较意外的例子,我们在A800上执行[batch, 4096] * [4096, 4096]的计算检查,这个矩阵shape对应Llama3-8B的o-proj,我们发现在batch=1024的情况下有明显的性能问题(MFU大约55%),通过splitwise_gemm将[1024, 4096] ,拆成[384, 4096] 和[640, 4096],矩阵计算的时间从191.8us优化到148.8us。事实上,对于bfloat16类型的乘法,可以发现当batch在772到1388的范围区间内,或者1668到2048的范围区间内,GEMM的MFU都是有明显异常的,详细可以参见Github测试结果