[cache refactor] Move all the caching logic to a per-layer approach (#39106)

* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests

* fix quantized, add tests

* remove CacheProcessorList

* raushan review, arthur review

* joao review: minor things

* remove cache configs, make CacheLayer a mixin (joaos review)

* back to storage inside Cache()

* remove cachebase for decorator

* no more __getattr__

* fix tests

* joaos review except docs

* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`

More verbose exceptions in `fix_docstring` on docstring formatting issues.

* Revert "back to storage inside Cache()"

This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913.

* cyril review

* simplify cache export

* fix lfm2 cache

* HybridChunked to layer

* BC proxy object for cache.key_cache[i]=...

* reorder classes

* bfff come on LFM2

* better tests for hybrid and hybridChunked

* complete coverage for hybrid chunked caches (prefill chunking)

* reimplementing HybridChunked

* cyril review

* fix ci

* docs for cache refactor

* docs

* oopsie

* oopsie

* fix after merge

* cyril review

* arthur review

* opsie

* fix lfm2

* opsie2
This commit is contained in:
Manuel de Prada Corral
2025-07-22 16:10:25 +02:00
committed by GitHub
parent b16688e96a
commit c338fd43b0
64 changed files with 2779 additions and 2441 deletions

View File

@@ -271,7 +271,7 @@ class T5GemmaModelTester:
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertIsNotNone(decoder_past)
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
def check_prepare_lm_labels_via_shift_left(
self,
@@ -1060,9 +1060,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1070,30 +1068,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1101,10 +1099,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
def test_load_with_mismatched_shapes(self):