[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests * fix quantized, add tests * remove CacheProcessorList * raushan review, arthur review * joao review: minor things * remove cache configs, make CacheLayer a mixin (joaos review) * back to storage inside Cache() * remove cachebase for decorator * no more __getattr__ * fix tests * joaos review except docs * fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. * Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. * cyril review * simplify cache export * fix lfm2 cache * HybridChunked to layer * BC proxy object for cache.key_cache[i]=... * reorder classes * bfff come on LFM2 * better tests for hybrid and hybridChunked * complete coverage for hybrid chunked caches (prefill chunking) * reimplementing HybridChunked * cyril review * fix ci * docs for cache refactor * docs * oopsie * oopsie * fix after merge * cyril review * arthur review * opsie * fix lfm2 * opsie2
This commit is contained in:
committed by
GitHub
parent
b16688e96a
commit
c338fd43b0
@@ -1577,9 +1577,7 @@ class GenerationTesterMixin:
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1587,30 +1585,30 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1618,10 +1616,18 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
if is_legacy_cache:
|
||||
self_attention_layer_keys = past_kv[i][0]
|
||||
self_attention_layer_values = past_kv[i][1]
|
||||
elif getattr(past_kv, "layers", None) is None:
|
||||
# Cache is not layered (i.e., Mamba derivatives)
|
||||
self_attention_layer_keys = past_kv.key_cache[i]
|
||||
self_attention_layer_values = past_kv.value_cache[i]
|
||||
else:
|
||||
self_attention_layer_keys = past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_from_random_inputs_embeds(self):
|
||||
@@ -1804,8 +1810,8 @@ class GenerationTesterMixin:
|
||||
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
|
||||
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
|
||||
self.assertIsInstance(outputs.past_key_values, StaticCache)
|
||||
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
|
||||
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
@@ -2027,8 +2033,8 @@ class GenerationTesterMixin:
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
|
||||
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
@@ -2629,12 +2635,12 @@ class GenerationTesterMixin:
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
[layer.keys.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
[layer.values.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
|
||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||
@@ -4040,13 +4046,13 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||
|
||||
# check device of each layer
|
||||
key_cache_0 = results.past_key_values.key_cache[0]
|
||||
value_cache_0 = results.past_key_values.value_cache[0]
|
||||
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||
keys_0 = results.past_key_values.layers[0].keys
|
||||
values_0 = results.past_key_values.layers[0].values
|
||||
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
|
||||
|
||||
key_cache_1 = results.past_key_values.key_cache[1]
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
keys_1 = results.past_key_values.layers[1].keys
|
||||
values_1 = results.past_key_values.layers[1].values
|
||||
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_multi_accelerator
|
||||
@@ -4118,13 +4124,13 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
|
||||
|
||||
# check device of each layer
|
||||
key_cache_0 = results.past_key_values.key_cache[0]
|
||||
value_cache_0 = results.past_key_values.value_cache[0]
|
||||
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||
keys_0 = results.past_key_values.layers[0].keys
|
||||
values_0 = results.past_key_values.layers[0].values
|
||||
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
|
||||
|
||||
key_cache_1 = results.past_key_values.key_cache[1]
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
keys_1 = results.past_key_values.layers[1].keys
|
||||
values_1 = results.past_key_values.layers[1].values
|
||||
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
|
||||
|
||||
@slow
|
||||
def test_padding_input_contrastive_search_gpt2(self):
|
||||
|
||||
Reference in New Issue
Block a user