<aside>
This is the English translation of my Chinese blog post originally published on May 25, 2025.
</aside>
Recently I have been reading this series of papers: DeltaNet[1], Gated DeltaNet[2], and RWKV-7[3]. DeltaNet is a form of linear attention that updates a state matrix via the delta rule, significantly improving expressivity over the basic linear attention[4], which can only incrementally add to the state matrix. Gated DeltaNet introduces decay to the state matrix on top of DeltaNet. RWKV-7 further extends this by promoting decay and in-context learning rates from scalars to vectors, and by generalizing the delta rule to a more universal update. Combined with optimizations in the RWKV family like token-shift, overall model expressivity improves further.
For me personally, this line of work marks a turning point: from skepticism about linear attention to believing it is viable. I've therefore been thinking about how to train and optimize inference for such models in practice. There are already papers on accelerating structures like DeltaNet and Mamba 2 on GPUs[5][6], but fewer articles analyze the details from an inference framework perspective. This post collects some recent thoughts and questions. Using DeltaNet as an example, I'll discuss practical deployment considerations for linear attention models from the perspective of inference frameworks, including:
For a detailed introduction to DeltaNet, including derivations and GPU acceleration, see the paper[5] or blog posts[7][8][9]. The reference implementation is in the Flash Linear Attention repository[10]. For completeness, here is a brief recap.
Parallel form of linear attention: For standard transformers, softmax attention is computed as $O=\mathrm{softmax}((QK^T) \odot M) V$, where $Q$ is query, $K$ is key, $V$ is value, and $M$ is the causal mask. Linear attention removes the softmax, yielding $O=((QK^T) \odot M) V$, which is the parallel form. Importantly, because of the causal mask $M$, you cannot simply re-associate the multiplications to compute $K^TV$ first, even though doing so would greatly reduce compute.
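To make the shapes concrete, here is a minimal PyTorch sketch of the parallel form for a single head (the function and variable names are my own, not taken from Flash Linear Attention):

```python
import torch

def linear_attention_parallel(q, k, v):
    """Parallel form of causal linear attention: O = ((Q K^T) ⊙ M) V.

    q, k, v: (T, D) tensors for a single head; M is the lower-triangular causal mask.
    """
    T = q.shape[0]
    scores = q @ k.transpose(-1, -2)                                  # (T, T), no softmax
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    return scores.masked_fill(~mask, 0.0) @ v                         # apply M, then multiply by V
```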
Recurrent form of linear attention: Now define the recurrent form used during decode. Suppose we've decoded $t-1$ tokens and are computing the $t$-th token. Let $K_t, V_t$ be the keys and values for the first $t$ tokens, each of shape $(t, D)$, where $D$ is the head dimension. Define $S_t = V_t^T K_t$. Then we can derive $S_t = S_{t-1} + v_t^T k_t$, where $k_t$ and $v_t$ are the current token's key and value row vectors, with shape $(1, D)$, and $S_t$ has shape $(D, D)$. The attention output is $o_t = q_t S_t^T$. One can show that this yields the same $O$ as the parallel form. The update $S_t = S_{t-1} + v_t^T k_t$ is called the recurrent form of linear attention, and $S_t$ is the state matrix. During decode, linear attention stores not a KVCache but this state matrix.
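A matching sketch of the recurrent form, again with my own names, using the row-vector convention above ($S_t = S_{t-1} + v_t^T k_t$, $o_t = q_t S_t^T$) and checking that token-by-token decoding reproduces the parallel form:

```python
import torch

def linear_attention_recurrent_step(S, q_t, k_t, v_t):
    """One decode step: update the (D, D) state matrix and produce o_t of shape (1, D)."""
    S = S + v_t.transpose(-1, -2) @ k_t        # S_t = S_{t-1} + v_t^T k_t
    return S, q_t @ S.transpose(-1, -2)        # o_t = q_t S_t^T

# Sanity check: decoding token by token matches the parallel form O = ((Q K^T) ⊙ M) V.
T, D = 8, 16
q, k, v = torch.randn(T, D), torch.randn(T, D), torch.randn(T, D)
parallel_out = (q @ k.T).tril() @ v            # causal mask via lower-triangular zeroing
S, outputs = torch.zeros(D, D), []
for t in range(T):
    S, o_t = linear_attention_recurrent_step(S, q[t:t+1], k[t:t+1], v[t:t+1])
    outputs.append(o_t)
assert torch.allclose(torch.cat(outputs), parallel_out, atol=1e-4)
```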
DeltaNet update: DeltaNet modifies the update to $S_t = S_{t-1}(I - \beta_t k_t^T k_t) + \beta_t v_t^T k_t$, where $\beta_t$ is the in-context learning rate. The state matrix is still $(D, D)$. In practice there are multiple heads, each with its own state matrix, so the total number of state elements is $H \cdot D \cdot D$.
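As a sketch, one DeltaNet decode step for a single head might look like the following (illustrative names; real implementations such as Flash Linear Attention use the chunked form mentioned below rather than a token-by-token loop):

```python
import torch

def deltanet_step(S, q_t, k_t, v_t, beta_t):
    """One DeltaNet decode step: S_t = S_{t-1}(I - beta_t k_t^T k_t) + beta_t v_t^T k_t.

    S: (D, D) state matrix for one head; q_t, k_t, v_t: (1, D) row vectors;
    beta_t: scalar in-context learning rate (k_t is assumed L2-normalized).
    """
    D = S.shape[-1]
    I = torch.eye(D, dtype=S.dtype, device=S.device)
    S = S @ (I - beta_t * k_t.transpose(-1, -2) @ k_t) + beta_t * v_t.transpose(-1, -2) @ k_t
    return S, q_t @ S.transpose(-1, -2)        # same readout o_t = q_t S_t^T as above
```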
In what follows, let $T$ be the context length, $HQ$ the number of query heads, $HK$ the number of key/value heads, $D$ the per-head dimension (typically 64, 128, or 256), and $C$ the chunk size used by DeltaNet's chunked computation (set to 64 by default in Flash Linear Attention[10]).
This section analyzes the sequence-length thresholds at which DeltaNet uses less storage than MHA, GQA, MLA, and other softmax-attention variants. While DeltaNet stores a state matrix, I'll keep using the term "KVCache" for consistency. To distinguish the settings, let $HK_{softmax}, D_{softmax}$ and $HK_{delta}, D_{delta}$ denote the parameters of softmax attention and of DeltaNet, respectively. Per layer, DeltaNet stores $HK_{delta} \cdot D_{delta}^2$ state elements regardless of sequence length, while softmax attention stores $2 \cdot T \cdot HK_{softmax} \cdot D_{softmax}$ KVCache elements (the factor of 2 covering K and V), so it is straightforward to derive the threshold where DeltaNet has an advantage: $T > \frac{HK_{delta} \cdot D_{delta}^2}{2 \cdot HK_{softmax} \cdot D_{softmax}}$.
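A quick back-of-the-envelope helper for the break-even point (per layer, per request, counting elements and ignoring dtype; the function name and the example head counts are my own):

```python
def deltanet_breakeven_T(hk_delta, d_delta, hk_softmax, d_softmax):
    """Sequence length beyond which DeltaNet's state is smaller than the softmax KVCache.

    DeltaNet state:  hk_delta * d_delta**2 elements, independent of T.
    Softmax KVCache: 2 * T * hk_softmax * d_softmax elements (K and V).
    """
    return hk_delta * d_delta**2 / (2 * hk_softmax * d_softmax)

print(deltanet_breakeven_T(8, 128, 8, 128))  # equal-parameter MHA case: 64.0, i.e. T > D/2
```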
For standard multi-head attention and assuming equal parameter budgets (i.e., $D, HK, HQ$ are the same), this simplifies to $T > \frac{D}{2}$. If $D=128$, DeltaNet uses less storage once $T > 64$. Reality is more complex: