如果在本文中发现一些问题或者逻辑错误,欢迎随时评论或者通过[email protected]联系到我。
开始于:June 22 2023
发布于: June 23 2023
最后修改: June 23 2023
Other versions: English
本文主要通过线性代数的推导,证明了对于使用了ALiBi embedding的模型(如Bloom和MPT-Instruct),以及使用了 absolute position embeddings (如GPT模型),可以通过改变矩阵计算的顺序,在保持结果一致的情况下,相比原生的multi-head attention的计算复杂度显著下降。当前业界针对原生的attention计算的主流优化方式是通过KVCache,我们的方法相比KVCache的方法显存占用只有一半,同时计算性能差距非常小。
对于大语言模型来说,在序列长度达到4K这个量级之后,KVCache的显存占用已经比模型的参数量还要大,显存成了继续扩大context大小的主要瓶颈。我们的算法相比基于KVCache的multi-head attention只使用了一半的Cache显存空间,因此可以将模型的最大context大小接近提升一倍。另外一个优势是,由于降低了显存空间,模型可以使用更大的batch来提升吞吐。
核心技巧是改变矩阵的计算顺序。给定三个矩阵$A\in R^{m, k}$, $B \in R ^{k\times p}$, $C \in R ^{p\times n}$,通过矩阵乘法的结合律知道 $(A \times B) \times C = A \times (B \times C)$,但是二者的计算复杂度是不一样的。$(A \times B) \times C$的计算复杂度是$2\cdot m \cdot k \cdot p + 2\cdot m \cdot p \cdot n = 2\cdot m \cdot p \cdot(k+n)$ ,$A \times (B \times C)$的计算复杂度是 $2\cdot m \cdot k \cdot n + 2\cdot k \cdot p \cdot n = 2\cdot n \cdot k \cdot(m+p)$。当$n$相比 $m$ 和 $p$ 都显著更小的时候,第二种计算顺序的性能会远好于第一种。假设$m = k = p = 4096$, $n = 1$,那么第一种计算顺序的计算复杂度是$2 ⋅ 4096 ⋅ 4096 ⋅ 4097$,第二种方式的计算复杂度是$2 * 1 * 4096 * 8192$,显著低于第一种。
对于大语言模型的推理来说,我们也观察到了类似的模式。在推理场景的multi-head attention的计算中,模型的key embeddings和value embeddings的shape都是(batch, num_head, seq_len, dim),但是query embedding的shape是(batch, num_head,1, dim)。当序列长度达到4096这种量级的时候,query embedding是远小于key和value的embedding的,因此可以交换矩阵相乘的顺序,进而达到节省计算开销的目的。下面是方法的概要描述,在标准的multi-head attention实现中,quey、key、value的embedding是分别计算的,然后通过query embedding和key embedding来计算self-attention的权重矩阵,之后将这个权重矩阵和value embedding进行相乘得到最终的结果。但是,通过仔细分析发现这个顺序是可以交换的,即从query的embedding出发,一直向下进行计算,得到最终的结果。下面的给出了前后两种实现的计算顺序,关于整个方法的详细描述可以参加后文的细节分析部分。
High Level Idea of the Method
对于标准的multi-head attention计算来说,key和value embedding的计算是在线推理的主要开销,因为这部分计算和序列长度是线性相关的。在标准的未使用KVCache或者我们的方法加速的实现中,矩阵的计算复杂度是$2 * batch * seq\_len * hidden\_size * hidden\_size$。这个计算首先是和序列长度线性相关的,其次有一个非常大的常数,比如在GPT-3中,hidden_size的大小是12288。这部分的巨大计算开销是当前业界主要推理的方案都采用KVCache的主要原因。即将这两个矩阵的计算结果缓存到GPU显存中,然后在生成下一个token的时候,只结算对应token的embedding,然后和显存中存储的之前token的embedding拼接起来。这个方法的主要缺点是巨大的显存开销。在每一层的decoder,KVCache一共使用 $2 \cdot batch \cdot seq\_len \cdot hidden\_size$个参数。对于常见的大语言模型如ChatGPT来说,每一层decoder的模型参数是$12 * hidden\_size * hidden\_size$。这个意味着,如果 $(batch ⋅ seq\_len) > (6 ⋅ hidden\_size)$,那么KVCache的显存占用是比模型还要大的。当前的context长度需求已经达到了这个标准,因此KVCache是显存的主要瓶颈。
下面按照query、key和value的处理将整个算法分成三个部分描述,每一部分重点阐述如何进行矩阵顺序的交换。为了使描述更为整洁清楚,下面使用$B$来代表模型输入的batch size,$T$代表序列长度,$N$代表多头注意力的头数,$D$代表模型的embedding大小,$H = N ⋅ D$代表模型的hidden size。同时我们使用query权重矩阵来代指multi-head attention中一开始的query embedding映射的权重矩阵,对于key和value也使用相同的描述。