[cache refactor] Move all the caching logic to a per-layer approach (#39106)

* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests

* fix quantized, add tests

* remove CacheProcessorList

* raushan review, arthur review

* joao review: minor things

* remove cache configs, make CacheLayer a mixin (joaos review)

* back to storage inside Cache()

* remove cachebase for decorator

* no more __getattr__

* fix tests

* joaos review except docs

* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`

More verbose exceptions in `fix_docstring` on docstring formatting issues.

* Revert "back to storage inside Cache()"

This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913.

* cyril review

* simplify cache export

* fix lfm2 cache

* HybridChunked to layer

* BC proxy object for cache.key_cache[i]=...

* reorder classes

* bfff come on LFM2

* better tests for hybrid and hybridChunked

* complete coverage for hybrid chunked caches (prefill chunking)

* reimplementing HybridChunked

* cyril review

* fix ci

* docs for cache refactor

* docs

* oopsie

* oopsie

* fix after merge

* cyril review

* arthur review

* opsie

* fix lfm2

* opsie2
This commit is contained in:
Manuel de Prada Corral
2025-07-22 16:10:25 +02:00
committed by GitHub
parent b16688e96a
commit c338fd43b0
64 changed files with 2779 additions and 2441 deletions

View File

@@ -168,14 +168,9 @@ class DeepseekV2ModelTest(CausalLMModelTest, unittest.TestCase):
expected_value_shape = expected_common_shape + (config.v_head_dim,)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_key_shape] * len(decoder_past_key_values.key_cache),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_value_shape] * len(decoder_past_key_values.value_cache),
)
for layer in decoder_past_key_values.layers:
self.assertEqual(layer.keys.shape, expected_key_shape)
self.assertEqual(layer.values.shape, expected_value_shape)
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
def test_generate_compilation_all_outputs(self):

View File

@@ -440,13 +440,11 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# difference: last dim
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
v_embed_dim = config.v_head_dim
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
# build the full cache shapes
num_hidden_layers = config.num_hidden_layers
all_cache_shapes = [
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
]
all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
@require_torch_large_accelerator

View File

@@ -399,12 +399,12 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
def _check_scores(self, batch_size, scores, generated_length, config):

View File

@@ -38,7 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model
from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model
from transformers.models.falcon_h1.modeling_falcon_h1 import (
FalconHybridMambaAttentionDynamicCache,
)
@@ -272,6 +272,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
{"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {}
)
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
cache_length,
config.hidden_size // config.num_attention_heads,
)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
)
self.assertListEqual(
[value_cache.shape for value_cache in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
)
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
else:
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
[True] * len(decoder_past_key_values),
)
# check shape key, value
self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
[expected_shape] * len(decoder_past_key_values),
)
self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
[expected_shape] * len(decoder_past_key_values),
)
def setUp(self):
self.model_tester = FalconH1ModelTester(self)
self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64)

View File

@@ -235,8 +235,8 @@ class GPTNeoXModelTester:
"""Deep copy a DynamicCache to reuse the same one multiple times."""
new_cache = cache
for i in range(len(cache)):
new_cache.key_cache[i] = cache.key_cache[i].clone()
new_cache.value_cache[i] = cache.value_cache[i].clone()
new_cache.layers[i].keys = cache.layers[i].keys.clone()
new_cache.layers[i].values = cache.layers[i].values.clone()
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
# We need to run both on a copy of the cache, otherwise it is modified in-place

View File

@@ -271,7 +271,7 @@ class T5GemmaModelTester:
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertIsNotNone(decoder_past)
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
def check_prepare_lm_labels_via_shift_left(
self,
@@ -1060,9 +1060,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1070,30 +1068,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1101,10 +1099,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
def test_load_with_mismatched_shapes(self):