<aside> 💡

如果在本文中发现一些问题或者逻辑错误,欢迎随时评论或者通过[email protected]联系到我。

发布于: 2025年11月12日

最后修改: 2025年11月12日

</aside>

这篇笔记中,我们讨论如何对DeltaNet做序列并行,这里面考虑的是单个GPU上不同SM间的序列并行,多GPU间的序列并行也是同理。

对于DeltaNet来说,主要包括如下的计算过程:

  1. WY表示计算
  2. Recurrent的方式更新状态 $\mathbf S$
  3. 计算输出 $\mathbf O$

其中步骤1和3本身没有recurrent依赖,序列之间的不同chunk是完全可以并行的,真正有recurrent依赖的是步骤2。下面是步骤2的计算公式(来自博客 DeltaNet Explained (Part II) 这篇文章的 Chunkwise Parallel Form for DeltaNet 这一节)

$$ \begin{align*} \mathbf{S}{[i+1]} &= \mathbf{S}{[i]} (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}) + \mathbf{U}{[i]}^\top \mathbf{K}{[i]} \\ &= \mathbf{S}{[i]} + \left(\mathbf{U}{[i]} - \mathbf{W}{[i]}\mathbf{S}{[i]}^\top\right)^\top \mathbf{K}_{[i]} && \in \mathbb{R}^{d\times d} \end{align*} $$

**其中 $\mathbf{S}{[i]} := \mathbf{S}{iC} \in \mathbb{R}^{d \times d}$ 代表第 $i$ 个chunk的初始state。

现在来看如何对DeltaNet做序列并行,考虑一个简单的例子,序列长度是8192,DeltaNet计算的Chunkwise计算对应的chunk大小是64,因此一共会切分成128个chunk,对应的初始states分别是 $\mathbf{S}{[0]}, \mathbf{S}{[1]}, \mathbf{S}{[2]}, ..., \mathbf{S}{[127]}$。如果不做序列并行,这128个states就会根据上面的公式在同一个SM上recurrent的执行,通过$\mathbf{S}{[0]}$ **计算出 *$\mathbf{S}{[1]}$,依次向前计算。 现在假设要做 CP=2 的序列并行,需要在2个SM上执行计算。那么就需要拆成两个子序列,每一个子序列的长度是 4096,其中 $\mathbf{S}{[0]}, \mathbf{S}{[1]}, ..., \mathbf{S}_{[63]}$,在第一个SM上计算,$\mathbf{S}{[64]}, \mathbf{S}{[65]}, ..., \mathbf{S}{[127]}*$ 在第二个SM上计算。主要问题是第二个SM上的所有states计算都依赖于 $\mathbf{S}{[64]}$。在序列并行方式下,在计算第二个SM的时候,并没有 $\mathbf{S}_{[64]}$ 这个真实states,只有一个不完整的states $\mathbf{S}{[64]}^{\Delta}$ **,在实际中这个$*\mathbf{S}{[64]}^{\Delta}*$ 可能就是一个全0的states,但是为了描述清晰,仍然用$\mathbf{S}_{[64]}^{\Delta}$ 来代指。我们有:

$$ \begin{align*} \mathbf{S}{[i+1]} &= \mathbf{S}{[i]} (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}) + \mathbf{U}{[i]}^\top \mathbf{K}{[i]} \\ \mathbf{S}{[i+1]}^{\Delta} &= \mathbf{S}{[i]}^{\Delta} (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}) + \mathbf{U}{[i]}^\top \mathbf{K}{[i]} \end{align*} $$

那么很容易得到, $\mathbf{S}{[i+1]} - \mathbf{S}{[i+1]}^{\Delta} = (\mathbf{S}{[i]} - \mathbf{S}{[i]}^{\Delta}) (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}$) 。同样可以得到$\mathbf{S}{[i+2]} - \mathbf{S}{[i+2]}^{\Delta} = (\mathbf{S}{[i]} - \mathbf{S}{[i]}^{\Delta}) (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]})(\mathbf{I}-\mathbf{W}{[i+1]}^\top \mathbf{K}{[i+1]})$ ,或者更一般的公式:

$$ \mathbf{S}{[i+k]} - \mathbf{S}{[i+k]}^{\Delta} = (\mathbf{S}{[i]} - \mathbf{S}{[i]}^{\Delta}) \prod_{j=0}^{k-1}(\mathbf{I}-\mathbf{W}{[i+j]}^\top \mathbf{K}{[i+j]}) $$

,很容易得到如下的序列并行方式:

  1. 第一个SM负责计算第一个子序列,这部分的计算逻辑和之前保持一致;
  2. 第二个SM负责计算第二个子序列,但是初始state为 $\mathbf{S}{[64]}^{\Delta}=0$;然后recurrent的计算 $*\mathbf{S}{[i+1]}^{\Delta}=\mathbf{S}{[i]}^{\Delta} + \left(\mathbf{U}{[i]} - \mathbf{W}{[i]}\mathbf{S}{[i]}^\top\right)^\top \mathbf{K}{[i]}$;同时也初始化一个scaling矩阵 $\mathbf M{[64]}=\mathbf I*$,$\mathbf M_{[i+1]}=\mathbf M_{[i]}(\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]})$。将 $\mathbf{S}{[i+1]}^{\Delta}$ **和 $*\mathbf M{[i+1]}*$ 存储到global memory中。
  3. 当第一个SM计算完成之后,就得到了真实的 $\mathbf{S}{[64]}$,那么就可以通过 *$\mathbf S{[i]}=\mathbf S_{[i]}^{\Delta} +\mathbf S_{[64]}\mathbf M_{i}$*来计算出来最终的$\mathbf S_{[i]}$。需要说明的是,这一步并不是recurrent的,可以和$\mathbf O$的计算fuse到一起。

上述计算的主要问题是,会额外增加 $\mathbf M_{[i]}$ 的从global memory的写出和读入两次操作。对于CP=4 的情况也是类似的,一共使用4个SM计算,每一个SM处理32个chunk。

  1. 第一个SM负责计算第一个子序列,这部分的计算逻辑和之前保持一致;
  2. 另外三个SM负责计算后三个子序列,但是初始state为 $\mathbf{S}{[32]}^{\Delta}=\mathbf{S}{[64]}^{\Delta}=\mathbf{S}{[96]}^{\Delta}=0$;然后在每一个SM上recurrent的计算 $*\mathbf{S}{[i+1]}^{\Delta}=\mathbf{S}{[i]}^{\Delta} + \left(\mathbf{U}{[i]} - \mathbf{W}{[i]}\mathbf{S}{[i]}^\top\right)^\top \mathbf{K}{[i]}$;同时也初始化一个scaling矩阵 $\mathbf M{[32]}=\mathbf M_{[64]}=\mathbf M_{[96]}=\mathbf I*$,$\mathbf M_{[i+1]}=\mathbf M_{[i]}(\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]})$。将 $\mathbf{S}{[i+1]}^{\Delta}$ **和 $*\mathbf M{[i+1]}*$ 存储到global memory中。