[cache refactor] Move all the caching logic to a per-layer approach (#39106)

* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests

* fix quantized, add tests

* remove CacheProcessorList

* raushan review, arthur review

* joao review: minor things

* remove cache configs, make CacheLayer a mixin (joaos review)

* back to storage inside Cache()

* remove cachebase for decorator

* no more __getattr__

* fix tests

* joaos review except docs

* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`

More verbose exceptions in `fix_docstring` on docstring formatting issues.

* Revert "back to storage inside Cache()"

This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913.

* cyril review

* simplify cache export

* fix lfm2 cache

* HybridChunked to layer

* BC proxy object for cache.key_cache[i]=...

* reorder classes

* bfff come on LFM2

* better tests for hybrid and hybridChunked

* complete coverage for hybrid chunked caches (prefill chunking)

* reimplementing HybridChunked

* cyril review

* fix ci

* docs for cache refactor

* docs

* oopsie

* oopsie

* fix after merge

* cyril review

* arthur review

* opsie

* fix lfm2

* opsie2
This commit is contained in:
Manuel de Prada Corral
2025-07-22 16:10:25 +02:00
committed by GitHub
parent b16688e96a
commit c338fd43b0
64 changed files with 2779 additions and 2441 deletions

View File

@@ -1577,9 +1577,7 @@ class GenerationTesterMixin:
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1587,30 +1585,30 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1618,10 +1616,18 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
if is_legacy_cache:
self_attention_layer_keys = past_kv[i][0]
self_attention_layer_values = past_kv[i][1]
elif getattr(past_kv, "layers", None) is None:
# Cache is lot layered (i.e, Mamba derivatives)
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
else:
self_attention_layer_keys = past_kv.layers[i].keys
self_attention_layer_values = past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@pytest.mark.generate
def test_generate_from_random_inputs_embeds(self):
@@ -1804,8 +1810,8 @@ class GenerationTesterMixin:
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
self.assertIsInstance(outputs.past_key_values, StaticCache)
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@@ -2027,8 +2033,8 @@ class GenerationTesterMixin:
num_hidden_layers = text_config.num_hidden_layers
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
@@ -2629,12 +2635,12 @@ class GenerationTesterMixin:
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
@@ -4040,13 +4046,13 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(isinstance(results.past_key_values, StaticCache))
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
keys_0 = results.past_key_values.layers[0].keys
values_0 = results.past_key_values.layers[0].values
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
keys_1 = results.past_key_values.layers[1].keys
values_1 = results.past_key_values.layers[1].values
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
@pytest.mark.generate
@require_torch_multi_accelerator
@@ -4118,13 +4124,13 @@ class GenerationIntegrationTests(unittest.TestCase):
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
keys_0 = results.past_key_values.layers[0].keys
values_0 = results.past_key_values.layers[0].values
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
keys_1 = results.past_key_values.layers[1].keys
values_1 = results.past_key_values.layers[1].values
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
@slow
def test_padding_input_contrastive_search_gpt2(self):