<aside> 💡
如果在本文中发现一些问题或者逻辑错误,欢迎随时评论或者通过[email protected]联系到我。
发布于: 2025年11月12日
最后修改: 2025年11月12日
</aside>
这篇笔记中,我们讨论如何对DeltaNet做序列并行,这里面考虑的是单个GPU上不同SM间的序列并行,多GPU间的序列并行也是同理。
对于DeltaNet来说,主要包括如下的计算过程:
其中步骤1和3本身没有recurrent依赖,序列之间的不同chunk是完全可以并行的,真正有recurrent依赖的是步骤2。下面是步骤2的计算公式(来自博客 DeltaNet Explained (Part II) 这篇文章的 Chunkwise Parallel Form for DeltaNet 这一节)
$$ \begin{align*} \mathbf{S}{[i+1]} &= \mathbf{S}{[i]} (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}) + \mathbf{U}{[i]}^\top \mathbf{K}{[i]} \\ &= \mathbf{S}{[i]} + \left(\mathbf{U}{[i]} - \mathbf{W}{[i]}\mathbf{S}{[i]}^\top\right)^\top \mathbf{K}_{[i]} && \in \mathbb{R}^{d\times d} \end{align*} $$
**其中 $\mathbf{S}{[i]} := \mathbf{S}{iC} \in \mathbb{R}^{d \times d}$ 代表第 $i$ 个chunk的初始state。
现在来看如何对DeltaNet做序列并行,考虑一个简单的例子,序列长度是8192,DeltaNet计算的Chunkwise计算对应的chunk大小是64,因此一共会切分成128个chunk,对应的初始states分别是 $\mathbf{S}{[0]}, \mathbf{S}{[1]}, \mathbf{S}{[2]}, ..., \mathbf{S}{[127]}$。如果不做序列并行,这128个states就会根据上面的公式在同一个SM上recurrent的执行,通过$\mathbf{S}{[0]}$ **计算出 *$\mathbf{S}{[1]}$,依次向前计算。 现在假设要做 CP=2 的序列并行,需要在2个SM上执行计算。那么就需要拆成两个子序列,每一个子序列的长度是 4096,其中 $\mathbf{S}{[0]}, \mathbf{S}{[1]}, ..., \mathbf{S}_{[63]}$,在第一个SM上计算,$\mathbf{S}{[64]}, \mathbf{S}{[65]}, ..., \mathbf{S}{[127]}*$ 在第二个SM上计算。主要问题是第二个SM上的所有states计算都依赖于 $\mathbf{S}{[64]}$。在序列并行方式下,在计算第二个SM的时候,并没有 $\mathbf{S}_{[64]}$ 这个真实states,只有一个不完整的states $\mathbf{S}{[64]}^{\Delta}$ **,在实际中这个$*\mathbf{S}{[64]}^{\Delta}*$ 可能就是一个全0的states,但是为了描述清晰,仍然用$\mathbf{S}_{[64]}^{\Delta}$ 来代指。我们有:
$$ \begin{align*} \mathbf{S}{[i+1]} &= \mathbf{S}{[i]} (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}) + \mathbf{U}{[i]}^\top \mathbf{K}{[i]} \\ \mathbf{S}{[i+1]}^{\Delta} &= \mathbf{S}{[i]}^{\Delta} (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}) + \mathbf{U}{[i]}^\top \mathbf{K}{[i]} \end{align*} $$
那么很容易得到, $\mathbf{S}{[i+1]} - \mathbf{S}{[i+1]}^{\Delta} = (\mathbf{S}{[i]} - \mathbf{S}{[i]}^{\Delta}) (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]}$) 。同样可以得到$\mathbf{S}{[i+2]} - \mathbf{S}{[i+2]}^{\Delta} = (\mathbf{S}{[i]} - \mathbf{S}{[i]}^{\Delta}) (\mathbf{I}-\mathbf{W}{[i]}^\top \mathbf{K}{[i]})(\mathbf{I}-\mathbf{W}{[i+1]}^\top \mathbf{K}{[i+1]})$ ,或者更一般的公式:
$$ \mathbf{S}{[i+k]} - \mathbf{S}{[i+k]}^{\Delta} = (\mathbf{S}{[i]} - \mathbf{S}{[i]}^{\Delta}) \prod_{j=0}^{k-1}(\mathbf{I}-\mathbf{W}{[i+j]}^\top \mathbf{K}{[i+j]}) $$
,很容易得到如下的序列并行方式:
上述计算的主要问题是,会额外增加 $\mathbf M_{[i]}$ 的从global memory的写出和读入两次操作。对于CP=4 的情况也是类似的,一共使用4个SM计算,每一个SM处理32个chunk。