[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests * fix quantized, add tests * remove CacheProcessorList * raushan review, arthur review * joao review: minor things * remove cache configs, make CacheLayer a mixin (joaos review) * back to storage inside Cache() * remove cachebase for decorator * no more __getattr__ * fix tests * joaos review except docs * fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. * Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. * cyril review * simplify cache export * fix lfm2 cache * HybridChunked to layer * BC proxy object for cache.key_cache[i]=... * reorder classes * bfff come on LFM2 * better tests for hybrid and hybridChunked * complete coverage for hybrid chunked caches (prefill chunking) * reimplementing HybridChunked * cyril review * fix ci * docs for cache refactor * docs * oopsie * oopsie * fix after merge * cyril review * arthur review * opsie * fix lfm2 * opsie2
This commit is contained in:
committed by
GitHub
parent
b16688e96a
commit
c338fd43b0
@@ -82,22 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
|
||||
|
||||
## Cache storage implementation
|
||||
|
||||
The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
|
||||
Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.
|
||||
|
||||
Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated.
|
||||
|
||||
In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`.
|
||||
- `key_cache`: A list of tensors, one for each layer.
|
||||
- `value_cache`: A list of tensors, one for each layer.
|
||||
The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
|
||||
|
||||
When new tokens are processed:
|
||||
|
||||
1. For each layer, the new key and value states are concatenated with the existing cache.
|
||||
```py
|
||||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
||||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
||||
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
|
||||
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
|
||||
```
|
||||
|
||||
2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.
|
||||
Other layer types like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
|
||||
|
||||
The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user