[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests * fix quantized, add tests * remove CacheProcessorList * raushan review, arthur review * joao review: minor things * remove cache configs, make CacheLayer a mixin (joaos review) * back to storage inside Cache() * remove cachebase for decorator * no more __getattr__ * fix tests * joaos review except docs * fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. * Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. * cyril review * simplify cache export * fix lfm2 cache * HybridChunked to layer * BC proxy object for cache.key_cache[i]=... * reorder classes * bfff come on LFM2 * better tests for hybrid and hybridChunked * complete coverage for hybrid chunked caches (prefill chunking) * reimplementing HybridChunked * cyril review * fix ci * docs for cache refactor * docs * oopsie * oopsie * fix after merge * cyril review * arthur review * opsie * fix lfm2 * opsie2
This commit is contained in:
committed by
GitHub
parent
b16688e96a
commit
c338fd43b0
@@ -168,14 +168,9 @@ class DeepseekV2ModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
expected_value_shape = expected_common_shape + (config.v_head_dim,)
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_key_shape] * len(decoder_past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_value_shape] * len(decoder_past_key_values.value_cache),
|
||||
)
|
||||
for layer in decoder_past_key_values.layers:
|
||||
self.assertEqual(layer.keys.shape, expected_key_shape)
|
||||
self.assertEqual(layer.values.shape, expected_value_shape)
|
||||
|
||||
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
|
||||
@@ -440,13 +440,11 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# difference: last dim
|
||||
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
v_embed_dim = config.v_head_dim
|
||||
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
# build the full cache shapes
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
all_cache_shapes = [
|
||||
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
|
||||
]
|
||||
all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]
|
||||
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
|
||||
|
||||
@require_torch_large_accelerator
|
||||
|
||||
@@ -399,12 +399,12 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
[layer.keys.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
[layer.values.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
|
||||
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||
|
||||
@@ -38,7 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model
|
||||
from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model
|
||||
from transformers.models.falcon_h1.modeling_falcon_h1 import (
|
||||
FalconHybridMambaAttentionDynamicCache,
|
||||
)
|
||||
@@ -272,6 +272,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
{"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
|
||||
|
||||
# (batch, head, seq_length, head_features)
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||
cache_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_cache.shape for value_cache in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
)
|
||||
|
||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||
else:
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
|
||||
[True] * len(decoder_past_key_values),
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
|
||||
[expected_shape] * len(decoder_past_key_values),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
|
||||
[expected_shape] * len(decoder_past_key_values),
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FalconH1ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64)
|
||||
|
||||
@@ -235,8 +235,8 @@ class GPTNeoXModelTester:
|
||||
"""Deep copy a DynamicCache to reuse the same one multiple times."""
|
||||
new_cache = cache
|
||||
for i in range(len(cache)):
|
||||
new_cache.key_cache[i] = cache.key_cache[i].clone()
|
||||
new_cache.value_cache[i] = cache.value_cache[i].clone()
|
||||
new_cache.layers[i].keys = cache.layers[i].keys.clone()
|
||||
new_cache.layers[i].values = cache.layers[i].values.clone()
|
||||
|
||||
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
|
||||
# We need to run both on a copy of the cache, otherwise it is modified in-place
|
||||
|
||||
@@ -271,7 +271,7 @@ class T5GemmaModelTester:
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertIsNotNone(decoder_past)
|
||||
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
|
||||
|
||||
def check_prepare_lm_labels_via_shift_left(
|
||||
self,
|
||||
@@ -1060,9 +1060,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1070,30 +1068,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1101,10 +1099,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
|
||||
Reference in New Issue
Block a user