From 78079abeffaa97996aee1218a3b0c77f9d079d9a Mon Sep 17 00:00:00 2001 From: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com> Date: Mon, 26 May 2025 15:53:41 +0200 Subject: [PATCH] Improved cache docs (#38060) * improved cache docs Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 4 +- docs/source/en/cache_explanation.md | 74 +++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9e03a234b7..8d4abf1337 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -76,12 +76,12 @@ title: Prompt engineering - local: llm_optims title: Optimizing inference + - local: cache_explanation + title: Caching - local: kv_cache title: KV cache strategies - local: serving title: Serving - - local: cache_explanation - title: Caching - local: llm_tutorial_optimization title: Getting the most out of LLMs - local: perplexity diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 59496e4298..0ccf612d21 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -15,8 +15,7 @@ rendered properly in your Markdown viewer. --> # Caching - -Imagine you’re having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right? +Imagine you're having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right? You can extend this analogy to transformer models. Autoregressive model generation can be slow because it makes a prediction one token at a time. Each new prediction is dependent on all the previous context. @@ -29,8 +28,50 @@ A key-value (KV) cache eliminates this inefficiency by storing kv pairs derived > [!WARNING] > Caching should only be used for **inference**. It may cause unexpected errors if it's enabled during training. +To better understand how and why caching works, let's take a closer look at the structure of the attention matrices. + +## Attention matrices + +The **scaled dot-product attention** is calculated as shown below for a batch of size `b`, number of attention heads `h`, sequence length so far `T`, and dimension per attention head `d_head`. + +$$ +\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V +$$ + +The query (`Q`), key (`K`), and value (`V`) matrices are projections from the input embeddings of shape `(b, h, T, d_head)`. + +For causal attention, the mask prevents the model from attending to future tokens. Once a token is processed, its representation never changes with respect to future tokens, which means \\( K_{\text{past}} \\) and \\( V_{\text{past}} \\) can be cached and reused to compute the last token's representation. + +$$ +\text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}]) +$$ + +At inference time, you only need the last token's query to compute the representation \\( x_t \\) that predicts the next token \\( t+1 \\). At each step, the new key and value vectors are **stored** in the cache and **appended** to the past keys and values. + +$$ +K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t) +$$ + +Attention is calculated independently in each layer of the model, and caching is done on a per-layer basis. + +Refer to the table below to compare how caching improves efficiency. + +| without caching | with caching | | | | +|---|---|---|---|---| +| for each step, recompute all previous `K` and `V` | for each step, only compute current `K` and `V` | | | | +| attention cost per step is **quadratic** with sequence length | attention cost per step is **linear** with sequence length (memory grows linearly, but compute/token remains low) | | | | + + + ## Cache class +A basic KV cache interface takes a key and value tensor for the current token and returns the updated `K` and `V` tensors. This is internally managed by a model's `forward` method. + +```py +new_K, new_V = cache.update(k_t, v_t, layer_idx) +attn_output = attn_layer_idx_fn(q_t, new_K, new_V) +``` + When you use Transformers' [`Cache`] class, the self-attention module performs several critical steps to integrate past and present information. 1. The attention module concatenates current kv pairs with past kv pairs stored in the cache. This creates attentions weights with the shape `(new_tokens_length, past_kv_length + new_tokens_length)`. The current and past kv pairs are essentially combined to compute the attention scores, ensuring a model is aware of previous context and the current input. @@ -39,6 +80,27 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s 3. It is also important to be aware of the `cache_position`. This is important if you want to reuse a prefilled [`Cache`] with the `forward` method because you have to pass a valid `cache_position` value. This indicates the input positions in a sequence. `cache_position` is unaffected by padding, and it always adds one more position for each token. For example, if a kv cache contains 10 tokens - regardless of pad tokens - the cache position for the next token should be `torch.tensor([10])`. +## Cache storage implementation + +The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`]. + + +In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`. +- `key_cache`: A list of tensors, one for each layer. +- `value_cache`: A list of tensors, one for each layer. + +When new tokens are processed: + +1. For each layer, the new key and value states are concatenated with the existing cache. +```py +self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) +self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) +``` + +2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token. + +3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token. + The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token. ```py @@ -72,10 +134,14 @@ for _ in range(max_new_tokens): print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," ``` - ## Legacy cache format -Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format has is dynamic because it grows as text is generated, similar to [`DynamicCache`]. +Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format is dynamic because it grows as text is generated, similar to [`DynamicCache`]. + +The legacy format is essentially the same data structure but organized differently. +- It's a tuple of tuples, where each inner tuple contains the key and value tensors for a layer. +- The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`. +- The format is less flexible and doesn't support features like quantization or offloading. If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format.