如果在本文中发现一些问题或者逻辑错误,欢迎随时评论或者通过[email protected]联系到我。

开始于:June 22 2023

发布于: June 23 2023

最后修改: June 23 2023

Other versions: English

本文的主要结论

本文主要通过线性代数的推导,证明了对于使用了ALiBi embedding的模型(如Bloom和MPT-Instruct),以及使用了 absolute position embeddings (如GPT模型),可以通过改变矩阵计算的顺序,在保持结果一致的情况下,相比原生的multi-head attention的计算复杂度显著下降。当前业界针对原生的attention计算的主流优化方式是通过KVCache,我们的方法相比KVCache的方法显存占用只有一半,同时计算性能差距非常小。

对于大语言模型来说,在序列长度达到4K这个量级之后,KVCache的显存占用已经比模型的参数量还要大,显存成了继续扩大context大小的主要瓶颈。我们的算法相比基于KVCache的multi-head attention只使用了一半的Cache显存空间,因此可以将模型的最大context大小接近提升一倍。另外一个优势是,由于降低了显存空间,模型可以使用更大的batch来提升吞吐。

核心技巧是改变矩阵的计算顺序。给定三个矩阵$A\in R^{m, k}$, $B \in R ^{k\times p}$, $C \in R ^{p\times n}$,通过矩阵乘法的结合律知道 $(A \times B) \times C = A \times (B \times C)$,但是二者的计算复杂度是不一样的。$(A \times B) \times C$的计算复杂度是$2\cdot m \cdot k \cdot p + 2\cdot m \cdot p \cdot n = 2\cdot m \cdot p \cdot(k+n)$ ,$A \times (B \times C)$的计算复杂度是 $2\cdot m \cdot k \cdot n + 2\cdot k \cdot p \cdot n = 2\cdot n \cdot k \cdot(m+p)$。当$n$相比 $m$ 和 $p$ 都显著更小的时候,第二种计算顺序的性能会远好于第一种。假设$m = k = p = 4096$, $n = 1$,那么第一种计算顺序的计算复杂度是$2 ⋅ 4096 ⋅ 4096 ⋅ 4097$,第二种方式的计算复杂度是$2 * 1 * 4096 * 8192$,显著低于第一种。

对于大语言模型的推理来说,我们也观察到了类似的模式。在推理场景的multi-head attention的计算中,模型的key embeddings和value embeddings的shape都是(batch, num_head, seq_len, dim),但是query embedding的shape是(batch, num_head,1, dim)。当序列长度达到4096这种量级的时候,query embedding是远小于key和value的embedding的,因此可以交换矩阵相乘的顺序,进而达到节省计算开销的目的。下面是方法的概要描述,在标准的multi-head attention实现中,quey、key、value的embedding是分别计算的,然后通过query embedding和key embedding来计算self-attention的权重矩阵,之后将这个权重矩阵和value embedding进行相乘得到最终的结果。但是,通过仔细分析发现这个顺序是可以交换的,即从query的embedding出发,一直向下进行计算,得到最终的结果。下面的给出了前后两种实现的计算顺序,关于整个方法的详细描述可以参加后文的细节分析部分。

High Level Idea of the Method

High Level Idea of the Method

KVCache:计算的天使,显存的恶魔

对于标准的multi-head attention计算来说,key和value embedding的计算是在线推理的主要开销,因为这部分计算和序列长度是线性相关的。在标准的未使用KVCache或者我们的方法加速的实现中,矩阵的计算复杂度是$2 * batch * seq\_len * hidden\_size * hidden\_size$。这个计算首先是和序列长度线性相关的,其次有一个非常大的常数,比如在GPT-3中,hidden_size的大小是12288。这部分的巨大计算开销是当前业界主要推理的方案都采用KVCache的主要原因。即将这两个矩阵的计算结果缓存到GPU显存中,然后在生成下一个token的时候,只结算对应token的embedding,然后和显存中存储的之前token的embedding拼接起来。这个方法的主要缺点是巨大的显存开销。在每一层的decoder,KVCache一共使用 $2 \cdot batch \cdot seq\_len \cdot hidden\_size$个参数。对于常见的大语言模型如ChatGPT来说,每一层decoder的模型参数是$12 * hidden\_size * hidden\_size$。这个意味着,如果 $(batch ⋅ seq\_len) > (6 ⋅ hidden\_size)$,那么KVCache的显存占用是比模型还要大的。当前的context长度需求已经达到了这个标准,因此KVCache是显存的主要瓶颈。

我们的方法

算法描述

下面按照query、key和value的处理将整个算法分成三个部分描述,每一部分重点阐述如何进行矩阵顺序的交换。为了使描述更为整洁清楚,下面使用$B$来代表模型输入的batch size,$T$代表序列长度,$N$代表多头注意力的头数,$D$代表模型的embedding大小,$H = N ⋅ D$代表模型的hidden size。同时我们使用query权重矩阵来代指multi-head attention中一开始的query embedding映射的权重矩阵,对于key和value也使用相同的描述。

  1. 第一步是计算query embedding,这部分计算和原生的实现没有区别,主要就是把当前token对应的hidden state乘以query权重,得到一个 $[B, N, D]$的query embedding矩阵。
  2. 第二部是计算query和key的注意力。在标准的multi-head attention实现中,首先通过将shape为 $(B, T, H)$的hidden states和shape为 $(H, H)$的key权重矩阵相乘,得到一个shape为$(B, T, H)$的矩阵,这个矩阵会进一步变换成为一个$(B, T, N, D)$的矩阵,用于把多头注意力的head这个维度独立出来处理。在KVCache的实现中,主要优化点是,在计算每一个token的时候,都把这个token的计算结果缓存到GPU显存中,再生成下一个token的时候,只需要计算下一个token自己的结果,所有之前的计算结果都是从显存中直接获取的。接下来是attention weights的计算,长度为$(B, T, N, D)$的矩阵会和第一步中的query矩阵相乘,得到attention weights。我们的方法对这个计算顺序进行了修改,先将第一步的query embedding和转置后的key权重矩阵直接相乘,得到了一个$(N, B, H)$的矩阵,然后再将这个矩阵转置后和shape为$(B, T, H)$的矩阵相乘。新的顺序里面,由于首先将比较小的query embedding参与计算,因此整体计算复杂度会明显降低。使用矩阵符号进行描述,假设$A$是当前token对应的hidden state,$B$是query权重,$C$是key的权重,$D$是全部的hidden states。那么标准的multi-head attention的实现思路是$(A × B) × (C × D)$,而我们的实现思路是$(A × B × C) × D$。需要额外说明的是,这部分的矩阵计算结果和原生的实现是完全一样的,因此原先的ALiBi和softmax的计算不会受到影响,也保证了整个attention weights的正确性。
  3. 第三步是根据value embedding和上述计算出的attention weights来计算最终的结果。和第二步一样通过交换顺序来提升性能。第二部计算出的query-key attention weights的shape是 $(B, N, T)$,我们首先将这个直接和shape为$(B, T, H)$的hidden states直接相乘,而不是像原生的方法一样用value权重和hidden states去相乘。这样就可以得到一个长度为 $(B, N, H)$的加权版本hidden states,之后将这个hidden states和value权重矩阵相乘得到最终的结果。同样的用矩阵符号描述的话,用$A$来代表第二步的attention weights,$B$代表整个hidden states,$C$代表value权重矩阵。标准的multi-head attention的计算顺序是$*A × (B × C)$,*而我们方法的计算顺序是$(A × B) × C$。