If you find any mistakes in this article, please feel free to contact me through a comment or [email protected].
Started writing on June 22 2023
Released on June 23 2023
Last updated on June 23 2023
Other versions: Chinese
In this article, we argue that the KVCache can be removed for language models that use ALiBi position embeddings (like Bloom, MPT-Instruct, Baichuan2-13B) or absolute position embeddings (like GPT), producing the same results with only a moderate drop in inference performance. More precisely, the KVCache can be replaced with a smaller HiddenStatesCache, which stores only the hidden states of the previous tokens and uses half the memory of the KVCache. This brings two advantages for LLM inference. First, since the KVCache memory size is the main bottleneck for very long contexts, this method can almost double the maximum context length the GPU memory can support. Second, the memory saved allows a larger batch size, which improves throughput.
Our method uses half the memory of KVCache-based multi-head attention, with a small drop in inference performance (about 5% in my tests). This is supported by theoretical analysis and initial experimental results, and we think the trade-off is acceptable given the large memory saving.
The main trick is a careful manipulation of the matrix multiplication order. Given three matrices
$A \in R^{m \times k}$, $B \in R^{k \times p}$, $C \in R^{p \times n}$, associativity of matrix multiplication gives $(A \times B) \times C = A \times (B \times C)$. The computational cost of $(A \times B) \times C$ is $2 \cdot m \cdot k \cdot p + 2 \cdot m \cdot p \cdot n = 2 \cdot m \cdot p \cdot (k + n)$, while the computational cost of $A \times (B \times C)$ is $2 \cdot m \cdot k \cdot n + 2 \cdot k \cdot p \cdot n = 2 \cdot n \cdot k \cdot (m + p)$. When $n$ is significantly smaller than $m$ and $p$, the second order is significantly faster than the first. For example, with $m = k = p = 4096$ and $n = 1$, the first order costs $2 \cdot 4096 \cdot 4096 \cdot 4097$ FLOPs, while the second costs only $2 \cdot 1 \cdot 4096 \cdot 8192$ FLOPs, which is much smaller.
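To make this concrete, here is a minimal sketch (the matrix sizes are only illustrative) that times both multiplication orders with NumPy and confirms they produce the same result:

```python
import time
import numpy as np

m = k = p = 4096
n = 1

A = np.random.randn(m, k).astype(np.float32)
B = np.random.randn(k, p).astype(np.float32)
C = np.random.randn(p, n).astype(np.float32)

# Order 1: (A @ B) @ C, roughly 2*m*p*(k+n) FLOPs
start = time.time()
out1 = (A @ B) @ C
t1 = time.time() - start

# Order 2: A @ (B @ C), roughly 2*n*k*(m+p) FLOPs
start = time.time()
out2 = A @ (B @ C)
t2 = time.time() - start

print(f"(A @ B) @ C: {t1:.4f}s, A @ (B @ C): {t2:.4f}s")
print("max difference:", np.abs(out1 - out2).max())  # identical up to float rounding
```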
For LLM inference, we observe such an optimization opportunity because the key and value embeddings have shape (batch, num_head, seq_len, dim), while the query embedding has shape (batch, num_head, 1, dim). When the sequence length is as large as 4096, the query embedding is far smaller than the key and value embeddings, and we can change the order of matrix multiplication to save computational cost. In the standard multi-head attention implementation, the query, key, and value embeddings are calculated separately, query-key self-attention is performed, and the resulting attention weights are multiplied with the value embeddings to get the final output. However, we can start the whole computation sequence from the query, which has a much smaller sequence dimension and computational cost. The following shows the high-level idea of the method, i.e., the computation sequence before and after our change.
High Level Idea of the Method
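For reference, here is a minimal sketch of the standard decoding step for one new token with a KVCache, following the shapes described above (the hyper-parameters are made up for illustration and are not the article's configuration):

```python
import torch

batch, num_head, dim = 2, 16, 64
hidden = num_head * dim
past_len = 4096  # tokens already processed

# Cached key/value embeddings of previous tokens: (batch, num_head, past_len, dim)
k_cache = torch.randn(batch, num_head, past_len, dim)
v_cache = torch.randn(batch, num_head, past_len, dim)

# Hidden state of the single new token: (batch, 1, hidden)
x = torch.randn(batch, 1, hidden)
wq, wk, wv = (torch.randn(hidden, hidden) for _ in range(3))

# Project the new token and reshape to (batch, num_head, 1, dim)
q = (x @ wq).view(batch, 1, num_head, dim).transpose(1, 2)
k_new = (x @ wk).view(batch, 1, num_head, dim).transpose(1, 2)
v_new = (x @ wv).view(batch, 1, num_head, dim).transpose(1, 2)

# Append the new key/value to the cache
k_cache = torch.cat([k_cache, k_new], dim=2)
v_cache = torch.cat([v_cache, v_new], dim=2)

# Standard attention: query-key scores, then a weighted sum of values
scores = q @ k_cache.transpose(-1, -2) / dim ** 0.5   # (batch, num_head, 1, past_len + 1)
weights = torch.softmax(scores, dim=-1)
out = weights @ v_cache                               # (batch, num_head, 1, dim)
out = out.transpose(1, 2).reshape(batch, 1, hidden)
```

The storage cost of `k_cache` and `v_cache` is exactly what our method aims to cut in half by caching hidden states instead.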
The calculation of the key and value embeddings is the largest computation during inference, with computational complexity $2 \cdot batch \cdot seq\_len \cdot hidden\_size \cdot hidden\_size$. This computation is linear in the sequence length and has a very large constant $2 \cdot hidden\_size \cdot hidden\_size$, where $hidden\_size = 12288$ for GPT-3. This cost is the reason the KVCache was introduced to store the key and value embeddings of previous tokens. However, it comes with a very high storage cost: the KVCache holds $2 \cdot batch \cdot seq\_len \cdot hidden\_size$ values per layer. Large language models like ChatGPT have approximately $12 \cdot hidden\_size \cdot hidden\_size$ parameters per layer. This means that if $batch \cdot seq\_len > 6 \cdot hidden\_size$, the KVCache uses more memory than the model itself. For a smaller model like LLaMA-13B, $hidden\_size = 5120$, so at a context length of 4096 the KVCache easily takes more space than the model parameters once the batch size reaches 8.
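This back-of-the-envelope arithmetic can be checked with a few lines of Python (the per-layer parameter count $12 \cdot hidden\_size^2$ is the usual transformer approximation, counting attention and MLP weights):

```python
def kv_cache_entries_per_layer(batch, seq_len, hidden_size):
    # Key + value embeddings for every cached token
    return 2 * batch * seq_len * hidden_size

def params_per_layer(hidden_size):
    # Approximate transformer layer size: attention (~4 h^2) + MLP (~8 h^2)
    return 12 * hidden_size * hidden_size

hidden_size = 5120          # LLaMA-13B
batch, seq_len = 8, 4096

cache = kv_cache_entries_per_layer(batch, seq_len, hidden_size)
params = params_per_layer(hidden_size)
print(cache, params, cache > params)
# 335544320 314572800 True -> the per-layer KVCache already exceeds that layer's parameters
```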
We describe the computation in three steps.
For ease of description, we use $B$ for the batch size, $T$ for the sequence length, $N$ for the number of heads, and $D$ for the head dimension; the hidden size is $H = N \cdot D$. We use the term query weights for the weight matrix of the query projection; key weights and value weights are defined analogously.
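In code, this notation corresponds to the following shapes (a sketch with made-up sizes, assuming a PyTorch-style layout):

```python
import torch

B, T, N, D = 2, 4096, 16, 64
H = N * D  # hidden size

hidden_states = torch.randn(B, T, H)   # hidden states of all T tokens
w_query = torch.randn(H, H)            # query weights
w_key = torch.randn(H, H)              # key weights (value weights are defined the same way)

# Key embedding for all tokens: (B, N, T, D)
key = (hidden_states @ w_key).view(B, T, N, D).transpose(1, 2)
# During decoding only the last token's query is needed: (B, N, 1, D)
query = (hidden_states[:, -1:, :] @ w_query).view(B, 1, N, D).transpose(1, 2)
```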