<aside> 💡

如果在本文中发现一些问题或者逻辑错误,欢迎随时评论或者通过[email protected]联系到我。

发布于: 2025年2月6日

最后修改: 2025年2月6日

</aside>

引言

排序是计算机领域最常见的计算任务之一。在本文中,我们分析了 PyTorch 框架中 torch.sort 算子在 GPU 上的性能优化点。以 A800 GPU 为例,使用 PyTorch 2.6.0(CUDA 12.6)对 1 亿个 int32 类型整数排序耗时 14.4 ms。A800 的理论内存带宽为 2 TB/s,从GPU全局内存(global memory)读取 1 亿个 int32 整数的理论时间为 0.2 ms。因此,该排序操作的时间相当于 72 次内存读取,或者 36 次完整的内存读写

为了评估这个排序时间是否合理,我们使用 PyTorch Profiler 分析 torch.sort 算子的内部实现,将端到端排序时间拆解到各个 GPU kernel,并借助 Nsight Compute 进一步分析关键 kernel 的性能,发现了多个优化点。除了常见的冗余数据拷贝等问题,我们特别注意到 DeviceRadixSortOnesweep 等核心 GPU kernel 在 A800 上的带宽利用率低于 40%。这一性能瓶颈与 CUDA 处理 PyTorch OpaqueType 类型输入数据的方式密切相关。

针对上述性能问题,我们进行了一系列优化。在 A800、H800、H20 等 GPU 上,torch.sort 性能均提升35% 以上。以 A800 为例,1 亿个 int32 整数的排序时间从 14.4 ms 优化至 6.4 ms。此外,torch.unique 等基于排序的算子同样受益于这些优化,性能显著提升。

本文的主要目标是拆解和分析这些优化点,并得出 PyTorch 相关的优化 commit,以便后续提交 PR 或作为进一步优化的参考。同时,对于 PyTorch 框架广泛使用的 OpaqueType 类型,我们通过分析 GPU PTX 指令,解释其对 kernel 性能的影响,以便在未来的算子实现中规避类似性能问题。

主要结果

表格1:本文的主要优化结果,排序输入为一亿个int32整数,原始性能对应PyTorch版本2.6.0+cuda12.6

GPU 原始性能 (ms) 优化后性能 (ms)
A800 14.2 6.4
H800 6.7 4.1
H20 10.8 6.0

注1:A800和H800是面向中国的特殊定制版GPU,和A100、H100差异主要是NVLink的性能,单机算子的性能基本一致。

注2:文章中无论是官方PyTorch还是源码编译的优化版本均基于cuda12.6,PyTorch 2.6.0官方pip安装默认基于cuda12.4,性能相比12.6会有轻微的下降。

注3:这篇文章只研究shape为(N, ) 的一维tensor的排序,shape为(N, T)这种基于SegmentSort的二维tensor排序是另一个优化方向,会在另一篇文章中单独分析。

PyTorch当前实现分析

接口定义

下面是PyTorch中torch.sort的算子定义:

# torch.sort
sorted, indices = torch.sort(input, dim=-1, descending=False, stable=False, out=None)

其中各参数含义如下: