[cache refactor] Move all the caching logic to a per-layer approach (#39106)

* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests

* fix quantized, add tests

* remove CacheProcessorList

* raushan review, arthur review

* joao review: minor things

* remove cache configs, make CacheLayer a mixin (joaos review)

* back to storage inside Cache()

* remove cachebase for decorator

* no more __getattr__

* fix tests

* joaos review except docs

* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`

More verbose exceptions in `fix_docstring` on docstring formatting issues.

* Revert "back to storage inside Cache()"

This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913.

* cyril review

* simplify cache export

* fix lfm2 cache

* HybridChunked to layer

* BC proxy object for cache.key_cache[i]=...

* reorder classes

* bfff come on LFM2

* better tests for hybrid and hybridChunked

* complete coverage for hybrid chunked caches (prefill chunking)

* reimplementing HybridChunked

* cyril review

* fix ci

* docs for cache refactor

* docs

* oopsie

* oopsie

* fix after merge

* cyril review

* arthur review

* opsie

* fix lfm2

* opsie2
This commit is contained in:
Manuel de Prada Corral
2025-07-22 16:10:25 +02:00
committed by GitHub
parent b16688e96a
commit c338fd43b0
64 changed files with 2779 additions and 2441 deletions

View File

@@ -82,22 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
## Cache storage implementation
The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.
Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated.
In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`.
- `key_cache`: A list of tensors, one for each layer.
- `value_cache`: A list of tensors, one for each layer.
The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
When new tokens are processed:
1. For each layer, the new key and value states are concatenated with the existing cache.
```py
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
```
2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.
Other layer types like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.

View File

@@ -356,66 +356,93 @@ A [`Constraint`] can be used to force the generation to include specific tokens
## Caches
[[autodoc]] Cache
- update
[[autodoc]] CacheConfig
- update
[[autodoc]] QuantizedCacheConfig
- validate
[[autodoc]] DynamicCache
[[autodoc]] CacheLayerMixin
- update
- get_seq_length
- get_mask_sizes
- get_max_cache_shape
- reset
- reorder_cache
[[autodoc]] DynamicLayer
- update
- crop
- batch_repeat_interleave
- batch_select_indices
[[autodoc]] StaticLayer
- update
[[autodoc]] SlidingWindowLayer
- update
[[autodoc]] CacheProcessor
- pre_update
- post_update
[[autodoc]] OffloadedCacheProcessor
- pre_update
[[autodoc]] QuantizedCacheProcessor
- post_update
[[autodoc]] QuantoQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedCacheProcessor
- post_update
[[autodoc]] Cache
- update
- get_seq_length
- get_mask_sizes
- get_max_cache_shape
- reset
- reorder_cache
- crop
- batch_repeat_interleave
- batch_select_indices
[[autodoc]] DynamicCache
- to_legacy_cache
- from_legacy_cache
[[autodoc]] QuantizedCache
- update
- get_seq_length
[[autodoc]] QuantoQuantizedCache
[[autodoc]] QuantoQuantizedCacheProcessor
[[autodoc]] HQQQuantizedCache
[[autodoc]] HQQQuantizedCacheProcessor
[[autodoc]] OffloadedCache
- update
- prefetch_layer
- evict_previous_layer
[[autodoc]] StaticCache
- update
- get_seq_length
- reset
[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- reset
[[autodoc]] HybridCache
- update
- get_seq_length
- reset
[[autodoc]] HybridChunkedCache
[[autodoc]] SlidingWindowCache
- update
- reset
[[autodoc]] EncoderDecoderCache
- get_seq_length
- to_legacy_cache
- from_legacy_cache
- reset
- reorder_cache
[[autodoc]] MambaCache
- update_conv_state
- update_ssm_state
- reset
[[autodoc]] CacheConfig
[[autodoc]] QuantizedCacheConfig
## Watermark Utils
[[autodoc]] WatermarkingConfig

View File

@@ -134,7 +134,7 @@ The [`QuantizedCache`] reduces memory requirements by quantizing the KV values t
> [!WARNING]
> Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency.
Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and indicate the quantization backend in [`QuantizedCacheConfig`]. Any additional quantization related parameters should also be passed either as a dict or an instance of [`QuantizedCacheConfig`]. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length.
Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and the quantization backend, as well as any additional quantization related parameters should also be passed either as a dict. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length.
<hfoptions id="quantized-cache">
<hfoption id="HQQQuantizedCache">
@@ -143,7 +143,7 @@ For [`HQQQuantizedCache`], we recommend setting the `axis-key` and `axis-value`
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
@@ -161,7 +161,7 @@ For [`QuantoQuantizedCache`], we recommend setting the `axis-key` and `axis-valu
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
@@ -275,7 +275,6 @@ from transformers.cache_utils import (
StaticCache,
SlidingWindowCache,
QuantoQuantizedCache,
QuantizedCacheConfig,
)
model_id = "meta-llama/Llama-2-7b-chat-hf"

View File

@@ -345,12 +345,6 @@ generation_output[:2]
[[autodoc]] Cache
- update
[[autodoc]] CacheConfig
- update
[[autodoc]] QuantizedCacheConfig
- validate
[[autodoc]] DynamicCache
- update
- get_seq_length