<aside> 💡
注:线性注意力并不是我熟悉的领域,也没有相关的生产化落地经验,这里写的所有东西,主要是业余时间读技术报告、博客、以及一些开源代码后做一些计算得到的,并不是专业性的意见,更多的是因为相关领域资料和数据并不是很多,所以列举一下自己能想到的问题听一些更专业的意见反馈。文章里面有可能包含有个人的偏见和错误理解,如果有发现不对的地方请告诉我来修改。
发布于: 2025年5月25日
最后修改: 2025年5月25日
English version: DeltaNet-from-the-Inference-Framework-Perspective
</aside>
最近在读DeltaNet[1],Gated DeltaNet[2]、RWKV-7[3]这一系列文章。DeltaNet是linear attention的一种实现,通过delta rule来对状态矩阵(state matrix)进行更新,相比只能对状态矩阵进行简单增量累加的基础版linear attention[4]显著提升了模型表达能力。Gated DeltaNet在DeltaNet的基础上引入了状态矩阵的decay机制;RWKV-7在此基础上进一步扩展,将decay、in-context learning rates等参数从scalar提升成向量,同时将delta rule的更新策略推广到更为通用的形式,结合RWKV系列模型原本就有的token-shift等优化策略,整体模型表达能力得到了进一步提升。
从个人角度,这一系列工作是我对linear attention从比较疑虑到开始觉得可行的一个分水岭 ,因此平时会开始思考一下这种模型在实际中应该如何做训练和推理优化。当前在GPU上对DeltaNet、Mamba 2等结构进行加速的文章可以参考[5][6],但是从推理框架视角进行细节分析的文章并不多,因此在这里对最近的一些想法和遇到的疑问进行整理。具体来说,这篇文章会以DeltaNet为例,从推理框架的角度分析linear attention模型在实际部署中需要注意的一些问题,包括如下几个方面:
关于DeltaNet的详细介绍,包括算法推导以及GPU上性能加速可以参见论文[5],或者博客文章[7][8][9],DeltaNet的源码实现参见Flash Linear Attention的代码仓库[10]。在这里,为了文章的完整性给出简单的介绍。
线性注意力的parallel form公式:标准的transformer模型对应的softmax attention计算公式是: $O=softmax((QK^T) \odot M ) V$ ,其中 Q 代表query, $K$ 代表key, V 代表value, M 是causal mask矩阵。Linear attention将softmax计算去掉,得到 $O=((QK^T) \odot M ) V$ ,这个公式就是linear attention的parallel form计算公式。需要特别说明的是,由于causal mask M的存在,所以无法简单的直接通过结合律先计算 $K^TV$,即使这种交换能大幅度降低计算成本。
线性注意力的recurrent form公式:下面定义标准linear attention在decode阶段的recurrent form计算公式。假设当前已经完成 $t−1$ 个token的decode,现在需要计算第 $t$ 个token的结果。定义 $K_t,V_t$ 分别代表前 $t$ 个token对应的key和value,shape均为 $(t,D)$ ,其中 $D$ 是head_dim。定义 $S_t=V_t^TK_t$ ,那么可以推导出来 $S_t=S_{t−1}+v_t^Tk_t$ ,其中$k_t$ 和 $v_t$ 分别代表当前token的key和value向量,shape为 $(1,D)$ , St 的shape是 $(D,D)$ 。最终的attention计算结果为 $o_t=q_t⋅S_t$ ,可以证明这种方式计算出的 $O$ 和parallel form公式是一致的。公式 $S_t=S_{t−1}+v_t^Tk_t$ 被称为linear attention的recurrent form,其中 $S_t$ 被称为state matrix。在decode阶段,linear attention存储的不是KVCache,而是这个state matrix。
DeltaNet的计算公式:DeltaNet相比标准的linear attention做了一些改进,主要是将 St 的更新公式改成了 $S_t=S_{t−1}(I−\beta_tk_t^Tk_t)+β_tv_t^Tk_t$ ,因此state matrix同样是 $(D,D)$ 的方阵,当然在实际中还有多个head,每一个head都对应一个独立的state matrix,所以实际中state matrix的元素个数是 $H∗D∗D$ 。
在接下来,使用 $T$ 代表请求的上下文长度, $HQ$ 代表query的head个数, $HK$ 代表key和value的head个数,$D$ 代表每一个head的维度(通常是64、128或者256), $C$ 代表DeltaNet里面按照chunk进行计算对应的chunk大小(Flash Linear Attention[10]代码中设置默认值为64)
这一节主要分析在什么样的序列长度下,DeltaNet相比MHA、GQA、MLA等softmax attention变种使用更少的存储空间。DeltaNet存储的是state matrix,但是为了统一术语这里仍然用KVCache。为了区分不同结构,我们使用 $HK_{softmax}, D_{softmax}$ 和 $HK_{delta}, D_{delta}$ 来分别代表softmax attention和DeltaNet的参数。很容易计算得到,DeltaNet相比softmax attention更有优势的序列长度阈值是:$T>\frac{D_{delta}D_{delta}{HK_{delta}}}{2*D_{softmax}*HK_{softmax}}$ 。
对于Multi-head Attention的情况比较简单,假设softmax attention和DeltaNet使用相同的参数量(即 $D, HK, HQ$ 等参数二者保持相等),那么直接可以简化到 $T > \frac{D}{2}$ ,假设 $D=128$ 则 $T > 64$ 的时候DeltaNet的存储空间更有优势。但是实际情况要更复杂: