Matrix multiplication (GEMM) and attention are arguably the two most important operators in large-model computation, and this article focuses on matrix multiplication. On GPUs, the measured performance of GEMM usually falls short of the theoretical peak. In this article, we use MFU (Model FLOPs Utilization) to denote the ratio of the FLOPS a GEMM actually achieves to the GPU's theoretical peak FLOPS. Previous studies [1][2][3] show that the MFU of GEMM depends on many factors, including the GPU model, the GPU's actual power draw, the matrix shape, wave quantization, and the software implementation (e.g., cuBLAS or CUTLASS). With so many variables, a practical problem arises: there are no reference benchmarks for judging whether GEMM performance is reasonable for a given GPU, software version, and matrix shape. In this article, we propose a set of rules for checking whether measured GEMM performance is reasonable, which we call GEMM performance consistency rules.
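As a concrete reading of the MFU definition above, the sketch below counts 2·M·K·N floating-point operations for an [M, K] × [K, N] GEMM (one multiply and one add per inner-product term) and divides the achieved FLOPS by a peak value. The function names are illustrative, and `peak_flops` is a parameter the caller must take from the GPU's datasheet for the relevant data type; it is not a value given in this article.

```python
def gemm_flops(m: int, k: int, n: int) -> int:
    """FLOPs of an [m, k] x [k, n] GEMM: one multiply + one add per term."""
    return 2 * m * k * n


def gemm_mfu(m: int, k: int, n: int, kernel_time_s: float, peak_flops: float) -> float:
    """Ratio of achieved FLOPS to the GPU's theoretical peak FLOPS.

    kernel_time_s: measured (average) GEMM kernel time in seconds.
    peak_flops: theoretical peak FLOPS for the data type (assumed input).
    """
    achieved_flops = gemm_flops(m, k, n) / kernel_time_s
    return achieved_flops / peak_flops
```

The kernel time should be the GPU-side kernel duration rather than wall-clock time around an asynchronous launch; otherwise the MFU is underestimated.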
Based on these consistency rules, we automatically identified a series of GEMM performance anomalies. Here are three representative cases:
The causes and solutions of these performance issues are not the focus of this article. The main purpose here is to demonstrate that through certain rules, we can discover performance inconsistency issues that actually occur in practice and have negative impacts. This article mainly discusses methods for automatic GEMM testing to discover these issues.
All experimental results and the scripts to reproduce them are available on GitHub. The experiments were run on a cloud service that rents GPUs to individuals.
In [1][2][3], the authors point out that, because of tiling strategies and wave quantization, computational efficiency varies across matrix shapes. In practice, however, we generally expect that when matrix A is larger than matrix B (for example, A has more rows than B with the other dimensions unchanged, as in the experiments below), A's MFU should be roughly equal to or slightly higher than B's, since larger matrices make better use of the GPU's parallelism. If A's MFU is significantly lower than B's, we can reduce latency by splitting A into two smaller matrices and computing them separately. We call this the consistency principle of GEMM MFU with respect to matrix shape.
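The splitting argument can be made precise with a short derivation. As an illustration (the symbols are introduced here, not taken from the experiments), let a GEMM have shape $[N, K] \times [K, D]$, let $P$ be the GPU's theoretical peak FLOPS, and let $\mathrm{MFU}(N)$ be the measured MFU at batch $N$. Since the GEMM performs $2NKD$ floating-point operations, its latency is

$$
t(N) = \frac{2NKD}{\mathrm{MFU}(N)\,P}.
$$

Splitting $N = N_1 + N_2$ into two GEMMs that share the same $[K, D]$ matrix gives

$$
t(N_1) + t(N_2) = \frac{2KD}{P}\left(\frac{N_1}{\mathrm{MFU}(N_1)} + \frac{N_2}{\mathrm{MFU}(N_2)}\right),
$$

so, ignoring the overhead of the extra kernel launch, the split is faster exactly when

$$
\frac{N_1}{\mathrm{MFU}(N_1)} + \frac{N_2}{\mathrm{MFU}(N_2)} < \frac{N}{\mathrm{MFU}(N)}.
$$

In particular, this always holds when both $\mathrm{MFU}(N_1)$ and $\mathrm{MFU}(N_2)$ exceed $\mathrm{MFU}(N)$: each term on the left is then smaller than $N_i / \mathrm{MFU}(N)$, and those two bounds sum to $N / \mathrm{MFU}(N)$.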
We ran a series of experiments on the RTX4090, performing float16 matrix multiplications of shape [batch, 4096] × [4096, 14336]. The batch values were 4, 8, 12, ..., 32768 (step 4), for a total of 8192 matrix multiplications. The [4096, 14336] matrix corresponds to the up-proj and gate-proj weights of Llama3-8B. For each batch, we ran torch.matmul multiple times, measured the average kernel execution time on the GPU, computed the achieved FLOPS, and divided it by the RTX4090's theoretical peak to obtain 8192 GEMM MFU values. For each batch N, if the MFU of [N, 4096] is lower than a given ratio (e.g., 0.85) times the MFU of any smaller matrix [N1, 4096] (N1 < N), we mark N as an abnormal batch. The analysis script produced the results below; to rule out hardware-specific effects, we ran the test on two different machines. About 35% of the 8192 batches show abnormal performance at the 0.85 threshold.
GPU | MFU Ratio Threshold | Abnormal Batches | Total Batches | Abnormal Batches (%) |
---|---|---|---|---|
RTX4090 Instance 1 | 0.9 | 3021 | 8192 | 36.9% |
RTX4090 Instance 1 | 0.85 | 2882 | 8192 | 35.2% |
RTX4090 Instance 2 | 0.9 | 3135 | 8192 | 38.3% |
RTX4090 Instance 2 | 0.85 | 2882 | 8192 | 35.2% |
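The abnormal-batch rule can be sketched in a few lines of Python. The function below is a hypothetical helper, not the article's analysis script: it takes a mapping from batch size to measured MFU and flags a batch N whenever its MFU falls below `ratio` times the MFU of some smaller batch. Keeping a running maximum of the MFU seen so far makes a single pass over the sorted batches sufficient.

```python
def find_abnormal_batches(mfu_by_batch: dict[int, float], ratio: float = 0.85) -> list[int]:
    """Return batches whose MFU < ratio * MFU of some strictly smaller batch."""
    abnormal = []
    best_smaller = 0.0  # highest MFU among batches already visited (all smaller)
    for batch in sorted(mfu_by_batch):
        mfu = mfu_by_batch[batch]
        # Comparing with the running maximum is equivalent to checking every
        # smaller batch: if any smaller batch triggers the rule, the max does too.
        if mfu < ratio * best_smaller:
            abnormal.append(batch)
        best_smaller = max(best_smaller, mfu)
    return abnormal
```

Fed with the 8192 measured MFU values, the length of the returned list divided by 8192 gives the percentage column in the table above.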
To rule out measurement errors caused by a faulty GPU on a particular machine, we performed two additional checks on the results. First, we used two RTX4090 machines from different sources. Second, we manually sampled a number of matrix multiplications and re-verified their performance across shapes. Although performance varies somewhat between the two 4090 GPUs, the overall trend is the same.
Building on this finding, for each abnormal batch N we search over all pairs of smaller batches N1 and N2 with N1 + N2 = N, i.e., matrices [N1, 4096] and [N2, 4096]. We measure the GEMM latency of each and sum the two. If the sum is noticeably smaller than the latency at batch N, we classify [N, 4096] as a splitwise_gemm-optimizable matrix. splitwise_gemm computes the original matrix multiplication as two separate matrix multiplications over row blocks of A, each writing into the corresponding rows of C:
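```python
import torch

def splitwise_gemm(size_m, size_k, size_n, size_m_split):
    A = torch.randn(size_m, size_k, dtype=torch.float16, device="cuda")
    B = torch.randn(size_k, size_n, dtype=torch.float16, device="cuda")
    C = torch.empty(size_m, size_n, dtype=torch.float16, device="cuda")
    # First GEMM: rows [0, size_m_split) of A into the matching rows of C
    torch.matmul(A[0:size_m_split, :], B, out=C[0:size_m_split, :])
    # Second GEMM: remaining rows [size_m_split, size_m)
    torch.matmul(A[size_m_split:size_m, :], B, out=C[size_m_split:size_m, :])
    return C
```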
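The search over splits N1 + N2 = N can likewise be sketched in Python, assuming latencies have already been measured for every candidate batch (for example, on the same step-4 grid used in the sweep). `best_split` and its `margin` parameter are illustrative: the text only requires the summed latency to be "noticeably smaller", so the 5% default is an assumption.

```python
def best_split(n: int, latency_us: dict[int, float], margin: float = 0.05):
    """Return (n1, n2, summed_latency) for the fastest split of batch n,
    or None if no split beats latency_us[n] by at least `margin` (relative).
    Assumes latency_us[n] itself has been measured."""
    best = None
    for n1 in sorted(latency_us):
        n2 = n - n1
        # Proper splits only, both halves measured; n1 <= n2 skips mirrored pairs.
        if n1 <= 0 or n1 > n2 or n2 not in latency_us:
            continue
        total = latency_us[n1] + latency_us[n2]
        if best is None or total < best[2]:
            best = (n1, n2, total)
    if best is not None and best[2] < (1.0 - margin) * latency_us[n]:
        return best
    return None
```

Summing two separately measured kernel times ignores launch overhead, so a candidate found this way should still be confirmed by timing splitwise_gemm end to end.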
We also ran splitwise_gemm on several manually selected matrices and observed real speedups, as shown in the table below.
batch | GEMM MFU | GEMM kernel time (µs) | splitwise_gemm time (µs) | submatrices in splitwise_gemm |
---|---|---|---|---|
2180 | 0.7323 | 2106.2 | 1716.7 | [192, 4096], [1988, 4096] |
3076 | 0.7427 | 2930.2 | 2374.3 | [384, 4096], [2692, 4096] |
4228 | 0.7721 | 3874.0 | 3170.6 | [172, 4096], [4056, 4096] |
Note: The issue above on the RTX4090 appears only with float16; it does not occur with bfloat16. The exact cause is still unclear.
Another surprising example came from a [batch, 4096] × [4096, 4096] check on the A800, which corresponds to the o-proj of Llama3-8B. At a batch size of 1024, we found a significant performance drop (MFU around 55%). Using splitwise_gemm to split [1024, 4096] into [384, 4096] and [640, 4096] reduced the computation time from 191.8 µs to 148.8 µs (about 22%). More broadly, for bfloat16 multiplication, the GEMM MFU is clearly anomalous for batch sizes from 772 to 1388 and from 1668 to 2048. See the profiler results on GitHub for details.
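Anomalies reported as batch ranges, such as 772 to 1388, can be obtained by collapsing the sorted list of abnormal batches into contiguous runs. The helper below is hypothetical and assumes batches were sampled on a fixed grid (step 4 here, matching the RTX4090 sweep; the A800 step is an assumption).

```python
def to_ranges(batches: list[int], step: int = 4) -> list[tuple[int, int]]:
    """Collapse batch sizes into inclusive (start, end) runs spaced `step` apart."""
    ranges: list[tuple[int, int]] = []
    for b in sorted(batches):
        if ranges and b - ranges[-1][1] == step:
            ranges[-1] = (ranges[-1][0], b)  # extend the current run
        else:
            ranges.append((b, b))  # gap found: start a new run
    return ranges
```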