[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests * fix quantized, add tests * remove CacheProcessorList * raushan review, arthur review * joao review: minor things * remove cache configs, make CacheLayer a mixin (joaos review) * back to storage inside Cache() * remove cachebase for decorator * no more __getattr__ * fix tests * joaos review except docs * fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. * Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. * cyril review * simplify cache export * fix lfm2 cache * HybridChunked to layer * BC proxy object for cache.key_cache[i]=... * reorder classes * bfff come on LFM2 * better tests for hybrid and hybridChunked * complete coverage for hybrid chunked caches (prefill chunking) * reimplementing HybridChunked * cyril review * fix ci * docs for cache refactor * docs * oopsie * oopsie * fix after merge * cyril review * arthur review * opsie * fix lfm2 * opsie2
This commit is contained in:
committed by
GitHub
parent
b16688e96a
commit
c338fd43b0
@@ -271,7 +271,7 @@ class T5GemmaModelTester:
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertIsNotNone(decoder_past)
|
||||
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
|
||||
|
||||
def check_prepare_lm_labels_via_shift_left(
|
||||
self,
|
||||
@@ -1060,9 +1060,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1070,30 +1068,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1101,10 +1099,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
|
||||
Reference in New Issue
Block a user