[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests * fix quantized, add tests * remove CacheProcessorList * raushan review, arthur review * joao review: minor things * remove cache configs, make CacheLayer a mixin (joaos review) * back to storage inside Cache() * remove cachebase for decorator * no more __getattr__ * fix tests * joaos review except docs * fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. * Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. * cyril review * simplify cache export * fix lfm2 cache * HybridChunked to layer * BC proxy object for cache.key_cache[i]=... * reorder classes * bfff come on LFM2 * better tests for hybrid and hybridChunked * complete coverage for hybrid chunked caches (prefill chunking) * reimplementing HybridChunked * cyril review * fix ci * docs for cache refactor * docs * oopsie * oopsie * fix after merge * cyril review * arthur review * opsie * fix lfm2 * opsie2
This commit is contained in:
committed by
GitHub
parent
b16688e96a
commit
c338fd43b0
@@ -1577,9 +1577,7 @@ class GenerationTesterMixin:
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1587,30 +1585,30 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1618,10 +1616,18 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
if is_legacy_cache:
|
||||
self_attention_layer_keys = past_kv[i][0]
|
||||
self_attention_layer_values = past_kv[i][1]
|
||||
elif getattr(past_kv, "layers", None) is None:
|
||||
# Cache is not layered (i.e., Mamba derivatives)
|
||||
self_attention_layer_keys = past_kv.key_cache[i]
|
||||
self_attention_layer_values = past_kv.value_cache[i]
|
||||
else:
|
||||
self_attention_layer_keys = past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_from_random_inputs_embeds(self):
|
||||
@@ -1804,8 +1810,8 @@ class GenerationTesterMixin:
|
||||
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
|
||||
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
|
||||
self.assertIsInstance(outputs.past_key_values, StaticCache)
|
||||
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
|
||||
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
@@ -2027,8 +2033,8 @@ class GenerationTesterMixin:
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
|
||||
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
@@ -2629,12 +2635,12 @@ class GenerationTesterMixin:
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
[layer.keys.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
[layer.values.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
|
||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||
@@ -4040,13 +4046,13 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||
|
||||
# check device of each layer
|
||||
key_cache_0 = results.past_key_values.key_cache[0]
|
||||
value_cache_0 = results.past_key_values.value_cache[0]
|
||||
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||
keys_0 = results.past_key_values.layers[0].keys
|
||||
values_0 = results.past_key_values.layers[0].values
|
||||
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
|
||||
|
||||
key_cache_1 = results.past_key_values.key_cache[1]
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
keys_1 = results.past_key_values.layers[1].keys
|
||||
values_1 = results.past_key_values.layers[1].values
|
||||
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_multi_accelerator
|
||||
@@ -4118,13 +4124,13 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
|
||||
|
||||
# check device of each layer
|
||||
key_cache_0 = results.past_key_values.key_cache[0]
|
||||
value_cache_0 = results.past_key_values.value_cache[0]
|
||||
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||
keys_0 = results.past_key_values.layers[0].keys
|
||||
values_0 = results.past_key_values.layers[0].values
|
||||
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
|
||||
|
||||
key_cache_1 = results.past_key_values.key_cache[1]
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
keys_1 = results.past_key_values.layers[1].keys
|
||||
values_1 = results.past_key_values.layers[1].values
|
||||
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
|
||||
|
||||
@slow
|
||||
def test_padding_input_contrastive_search_gpt2(self):
|
||||
|
||||
@@ -168,14 +168,9 @@ class DeepseekV2ModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
expected_value_shape = expected_common_shape + (config.v_head_dim,)
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_key_shape] * len(decoder_past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_value_shape] * len(decoder_past_key_values.value_cache),
|
||||
)
|
||||
for layer in decoder_past_key_values.layers:
|
||||
self.assertEqual(layer.keys.shape, expected_key_shape)
|
||||
self.assertEqual(layer.values.shape, expected_value_shape)
|
||||
|
||||
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
|
||||
@@ -440,13 +440,11 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# difference: last dim
|
||||
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
v_embed_dim = config.v_head_dim
|
||||
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
# build the full cache shapes
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
all_cache_shapes = [
|
||||
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
|
||||
]
|
||||
all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]
|
||||
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
|
||||
|
||||
@require_torch_large_accelerator
|
||||
|
||||
@@ -399,12 +399,12 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
[layer.keys.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
[layer.values.shape for layer in decoder_past_key_values.layers],
|
||||
[expected_shape] * len(decoder_past_key_values.layers),
|
||||
)
|
||||
|
||||
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||
|
||||
@@ -38,7 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model
|
||||
from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model
|
||||
from transformers.models.falcon_h1.modeling_falcon_h1 import (
|
||||
FalconHybridMambaAttentionDynamicCache,
|
||||
)
|
||||
@@ -272,6 +272,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
{"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
    """Assert that every cached key/value tensor produced by generation has the expected shape.

    Accepts either a `Cache` object exposing `key_cache` / `value_cache` sequences, or the
    legacy format (an iterable of per-layer tuples whose first two items are keys and values).
    """
    self.assertIsInstance(decoder_past_key_values, (tuple, Cache))

    # Expected layout: (batch, head, seq_length, head_features)
    if hasattr(config, "num_key_value_heads"):
        num_heads = config.num_key_value_heads
    else:
        num_heads = config.num_attention_heads
    head_features = config.hidden_size // config.num_attention_heads
    expected_shape = (batch_size, num_heads, cache_length, head_features)

    if isinstance(decoder_past_key_values, Cache):
        # Keys first, then values, each compared against one expected shape per stored entry
        key_shapes = [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache]
        self.assertListEqual(key_shapes, [expected_shape] * len(decoder_past_key_values.key_cache))

        value_shapes = [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache]
        self.assertListEqual(value_shapes, [expected_shape] * len(decoder_past_key_values.value_cache))

    # Legacy cache format checks. This branch should be removed when all models use `Cache` by default
    else:
        layers_are_tuples = [isinstance(layer_past, tuple) for layer_past in decoder_past_key_values]
        self.assertListEqual(layers_are_tuples, [True] * len(decoder_past_key_values))

        # Index 0 holds the keys and index 1 the values of each layer tuple
        for kv_index in (0, 1):
            self.assertListEqual(
                [layer_past[kv_index].shape for layer_past in decoder_past_key_values],
                [expected_shape] * len(decoder_past_key_values),
            )
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FalconH1ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64)
|
||||
|
||||
@@ -235,8 +235,8 @@ class GPTNeoXModelTester:
|
||||
"""Deep copy a DynamicCache to reuse the same one multiple times."""
|
||||
new_cache = cache
|
||||
for i in range(len(cache)):
|
||||
new_cache.key_cache[i] = cache.key_cache[i].clone()
|
||||
new_cache.value_cache[i] = cache.value_cache[i].clone()
|
||||
new_cache.layers[i].keys = cache.layers[i].keys.clone()
|
||||
new_cache.layers[i].values = cache.layers[i].values.clone()
|
||||
|
||||
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
|
||||
# We need to run both on a copy of the cache, otherwise it is modified in-place
|
||||
|
||||
@@ -271,7 +271,7 @@ class T5GemmaModelTester:
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertIsNotNone(decoder_past)
|
||||
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
|
||||
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
|
||||
|
||||
def check_prepare_lm_labels_via_shift_left(
|
||||
self,
|
||||
@@ -1060,9 +1060,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# 3. Check cache shapes
|
||||
# 3.1. Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
num_cache_decoder_layers = (
|
||||
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
|
||||
)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1070,30 +1068,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
|
||||
self_attention_layer_keys = (
|
||||
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
|
||||
)
|
||||
self_attention_layer_value_cache = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
|
||||
self_attention_layer_values = (
|
||||
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
|
||||
)
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
# Cross attention (ignore 3rd dim, see default shape preparation)
|
||||
cross_attention_layer_key_cache = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
|
||||
cross_attention_layer_keys = (
|
||||
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
|
||||
)
|
||||
cross_attention_layer_value_cache = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
|
||||
cross_attention_layer_values = (
|
||||
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
|
||||
)
|
||||
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
|
||||
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
|
||||
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
|
||||
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
|
||||
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
|
||||
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
|
||||
|
||||
# 3.2. Decoder-only checks
|
||||
else:
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
|
||||
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
|
||||
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
@@ -1101,10 +1099,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
|
||||
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
|
||||
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
|
||||
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
|
||||
|
||||
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
|
||||
@@ -36,7 +36,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal
|
||||
from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -49,8 +49,12 @@ if is_torch_available():
|
||||
DynamicCache,
|
||||
Gemma2Config,
|
||||
GenerationConfig,
|
||||
HQQQuantizedCacheProcessor,
|
||||
HybridCache,
|
||||
HybridChunkedCache,
|
||||
LlamaConfig,
|
||||
QuantizedCache,
|
||||
QuantoQuantizedCacheProcessor,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
convert_and_export_with_cache,
|
||||
@@ -252,6 +256,59 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@parameterized.expand([("quanto",), ("HQQ",)])
def test_quantized_cache_generation(self, backend):
    """Tests that QuantizedCache works as expected for both `quanto` and `HQQ` backends.

    For each backend the test:
      1. skips when the backend library is not installed,
      2. generates greedily with `cache_implementation="quantized"`,
      3. checks the cache type and the backend-specific processor type,
      4. checks the decoded text against a known-good reference,
      5. checks that the processor actually holds quantized key data.
    """
    # Backend-specific settings. The backend string must match the parameterized value exactly
    # ("HQQ" is upper-case), otherwise the processor-type check below would be silently skipped.
    if backend == "quanto":
        if not is_optimum_quanto_available():
            self.skipTest("Quanto is not available")
        axis_key, axis_value = 0, 0
        expected_processor_cls = QuantoQuantizedCacheProcessor
        # This output is taken from a run with the same parameters, and is known to be correct
        expected_generation = ["The cat's whiskers are also a sign of anxiety."]
    elif backend == "HQQ":
        if not is_hqq_available():
            self.skipTest("HQQ is not available")
        axis_key, axis_value = 1, 1
        expected_processor_cls = HQQQuantizedCacheProcessor
        # With these settings the HQQ reference text is the same as the quanto one
        expected_generation = ["The cat's whiskers are also a sign of anxiety."]
    else:
        # An unknown backend must not pass silently
        self.fail(f"Unsupported quantized cache backend: {backend}")

    inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device)

    gen_out = self.model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=10,
        return_dict_in_generate=True,
        cache_implementation="quantized",
        cache_config={
            "backend": backend,
            "nbits": 4,
            "q_group_size": 16,
            "residual_length": 4,
            "axis_key": axis_key,
            "axis_value": axis_value,
        },
        disable_compile=True,
    )

    self.assertIsInstance(gen_out.past_key_values, QuantizedCache)
    processor = gen_out.past_key_values.cache_processor
    self.assertIsInstance(processor, expected_processor_cls)

    decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
    self.assertListEqual(decoded, expected_generation)

    self.assertTrue(len(processor._quantized_keys) > 0)

    # Check that something is actually quantized: entries may be tensors or tuples whose
    # first element is the tensor, and at least one of them must be non-empty.
    has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys)
    self.assertTrue(has_been_quantized)
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_extra_left_padding(self, cache_implementation):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the cache"""
|
||||
@@ -566,7 +623,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
past_key_values=DynamicCache(),
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
|
||||
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
|
||||
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
|
||||
self.assertEqual(
|
||||
3,
|
||||
@@ -587,11 +644,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
|
||||
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||
|
||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||
for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers):
|
||||
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
|
||||
|
||||
def test_dynamic_cache_exportability_multiple_run(self):
|
||||
# When exporting with DynamicCache, you should export two graphs:
|
||||
@@ -615,7 +670,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
past_key_values=DynamicCache(),
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
|
||||
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
|
||||
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
|
||||
self.assertEqual(
|
||||
3,
|
||||
@@ -640,9 +695,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
shapes = torch.export.ShapesCollection()
|
||||
dyn = torch.export.Dim("seq", max=512)
|
||||
|
||||
for ix in range(len(past_key_values.key_cache)):
|
||||
shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
|
||||
shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
|
||||
for ix in range(len(past_key_values)):
|
||||
shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
|
||||
shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)
|
||||
|
||||
ep_second = torch.export.export(
|
||||
model,
|
||||
@@ -683,11 +738,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||
|
||||
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||
for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers):
|
||||
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
|
||||
|
||||
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
|
||||
def test_static_cache_exportability(self):
|
||||
@@ -726,8 +779,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
|
||||
self.assertEqual(model.generation_config.max_length, max_cache_len)
|
||||
self.assertTrue(model.generation_config.cache_config is not None)
|
||||
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
|
||||
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
|
||||
self.assertEqual(model.generation_config.cache_config.get("batch_size"), batch_size)
|
||||
self.assertEqual(model.generation_config.cache_config.get("max_cache_len"), max_cache_len)
|
||||
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
|
||||
@@ -830,7 +883,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
head_dim=1,
|
||||
hidden_size=1,
|
||||
sliding_window=self.window_size,
|
||||
sliding_window_pattern=2, # Default pattern for hybrid sliding
|
||||
layer_types=["full_attention"] * 1, # Static cache by default
|
||||
)
|
||||
|
||||
def test_static_cache_out_of_bounds(self):
|
||||
@@ -867,7 +920,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2])},
|
||||
)
|
||||
self.assertEqual(
|
||||
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
|
||||
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
|
||||
)
|
||||
|
||||
# Scenario 2: Fill to capacity
|
||||
@@ -878,7 +931,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
self.assertEqual(
|
||||
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
|
||||
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
|
||||
)
|
||||
|
||||
def test_sliding_window_cache(self):
|
||||
@@ -897,7 +950,9 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
|
||||
"""
|
||||
# Scenario 1: Update within window, no slide yet
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
config = copy.deepcopy(self.config)
|
||||
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
|
||||
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -912,13 +967,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"SlidingWindowCache Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2: Update causing slide
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -933,13 +988,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[2.0, 3.0, 4.0, 5.0],
|
||||
"SlidingWindowCache Scenario 2 failed",
|
||||
)
|
||||
|
||||
# Scenario 3: Long prompt handling
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=long_prefill,
|
||||
@@ -948,13 +1003,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"SlidingWindowCache Scenario 3 failed",
|
||||
)
|
||||
|
||||
def test_hybrid_cache_static_mode(self):
|
||||
"""Test HybridCache in static mode with hardcoded assertions.
|
||||
"""Test HybridCache with only 1 static layer.
|
||||
|
||||
Scenario 1: Static layer behavior
|
||||
prefill: [1.0, 2.0, 0.0, 0.0]
|
||||
@@ -964,7 +1019,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
update pos 3: [1.0, 2.0, 3.0, 4.0]
|
||||
"""
|
||||
config = copy.deepcopy(self.config)
|
||||
config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0)
|
||||
config.layer_types = ["full_attention"] * config.num_hidden_layers
|
||||
|
||||
# Scenario 1
|
||||
hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
@@ -982,7 +1037,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2])},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"HybridCache Static Scenario 1 failed",
|
||||
)
|
||||
@@ -995,7 +1050,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
"HybridCache Static Scenario 2 failed",
|
||||
)
|
||||
@@ -1018,8 +1073,10 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
|
||||
"""
|
||||
config = copy.deepcopy(self.config)
|
||||
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
|
||||
# Scenario 1: Update within window, no slide yet
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -1034,13 +1091,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"HybridCache Sliding Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2: Update causing first slide
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -1055,7 +1112,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[2.0, 3.0, 4.0, 5.0],
|
||||
"HybridCache Sliding Scenario 2 failed",
|
||||
)
|
||||
@@ -1068,13 +1125,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"HybridCache Sliding Scenario 3 failed",
|
||||
)
|
||||
|
||||
# Scenario 4: Long prompt handling
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=long_prefill,
|
||||
@@ -1083,7 +1140,278 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"HybridCache Sliding Scenario 4 failed",
|
||||
)
|
||||
|
||||
def test_dynamic_cache(self):
    """Check that `DynamicCache` appends new key states independently per layer.

    Scenario 1 (one layer):
        prefill [1.0, 2.0], append 3.0 -> [1.0, 2.0, 3.0], append 4.0 -> [1.0, 2.0, 3.0, 4.0]
    Scenario 2 (two layers, updates interleaved layer 0 / layer 1):
        layer 0 ends as [1.0, 2.0, 3.0, 4.0]
        layer 1 ends as [10.0, 20.0, 30.0, 40.0]
    """

    def as_states(values):
        # Lay the values out along dim 2 of a 4-D tensor; all other dims are singletons.
        return torch.tensor(values)[None, None, :, None]

    def stored_keys(cache, layer_idx):
        # Read back the keys held by one layer as a flat Python list.
        return cache.layers[layer_idx].keys[0, 0, :, 0].tolist()

    # Per-layer sequence of updates: a two-token prefill followed by two single-token steps.
    layer0_steps = [as_states([1.0, 2.0]), as_states([3.0]), as_states([4.0])]
    layer1_steps = [as_states([10.0, 20.0]), as_states([30.0]), as_states([40.0])]

    # Scenario 1: prefill and update for one layer
    single_layer_cache = DynamicCache()
    prompt, third_token, fourth_token = layer0_steps
    single_layer_cache.update(prompt, prompt, 0)
    single_layer_cache.update(third_token, third_token, 0)
    self.assertEqual(stored_keys(single_layer_cache, 0), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
    single_layer_cache.update(fourth_token, fourth_token, 0)
    self.assertEqual(
        stored_keys(single_layer_cache, 0), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
    )

    # Scenario 2: prefill and update for two layers independently
    two_layer_cache = DynamicCache()
    for states_layer0, states_layer1 in zip(layer0_steps, layer1_steps):
        # Layer 0 is always updated before layer 1 within a step.
        two_layer_cache.update(states_layer0, states_layer0, 0)
        two_layer_cache.update(states_layer1, states_layer1, 1)

    self.assertEqual(
        stored_keys(two_layer_cache, 0), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
    )
    self.assertEqual(
        stored_keys(two_layer_cache, 1),
        [10.0, 20.0, 30.0, 40.0],
        "DynamicCache Scenario 2 layer 1 failed",
    )
|
||||
|
||||
def test_hybrid_cache(self):
    """
    Exercise a two-layer HybridCache (layer 0 full attention, layer 1 sliding attention
    with `sliding_window = 2`) using a 3-token prefill, i.e. longer than the window.

    After the prefill (cache positions 0..2):
        layer 0 stores and returns [1.0, 2.0, 3.0, 0.0]
        layer 1 returns the whole prompt [10.0, 20.0, 30.0] but stores only [20.0, 30.0]

    After one more token at cache position 3:
        layer 0 stores [1.0, 2.0, 3.0, 5.0]
        layer 1 stores [30.0, 50.0]
    """

    def first_channel(tensor):
        # Flatten the sequence dimension of a 4-D states tensor into a Python list.
        return tensor[0, 0, :, 0].tolist()

    mixed_config = copy.deepcopy(self.config)
    mixed_config.num_hidden_layers = 2
    mixed_config.layer_types = ["full_attention", "sliding_attention"]
    mixed_config.sliding_window = 2
    hybrid_cache = HybridCache(config=mixed_config, max_batch_size=1, max_cache_len=self.max_cache_len)

    # --- Prefill: three tokens per layer ---
    static_prompt = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
    sliding_prompt = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]

    static_call = {
        "key_states": static_prompt,
        "value_states": static_prompt,
        "layer_idx": 0,
        "cache_kwargs": {"cache_position": torch.arange(3)},
    }
    returned_static = hybrid_cache.update(**static_call)

    sliding_call = {
        "key_states": sliding_prompt,
        "value_states": sliding_prompt,
        "layer_idx": 1,
        "cache_kwargs": {"cache_position": torch.arange(3), "sliding_window": self.window_size},
    }
    returned_sliding = hybrid_cache.update(**sliding_call)

    # Stored state and returned keys, checked layer by layer.
    self.assertEqual(
        first_channel(hybrid_cache.layers[0].keys),
        [1.0, 2.0, 3.0, 0.0],
        "Initial static layer state is wrong",
    )
    self.assertEqual(
        first_channel(returned_static[0]),
        [1.0, 2.0, 3.0, 0.0],
        "Static layer did not return the correct value.",
    )
    self.assertEqual(
        first_channel(hybrid_cache.layers[1].keys),
        [20.0, 30.0],
        "Initial sliding layer state is wrong",
    )
    self.assertEqual(
        first_channel(returned_sliding[0]),
        [10.0, 20.0, 30.0],
        "Sliding layer did not return the correct value.",
    )

    # --- Decode: one new token per layer at cache position 3 (layer 0 first, then layer 1) ---
    for layer_idx, token_value in ((0, 5.0), (1, 50.0)):
        new_states = torch.tensor(token_value)[None, None, None, None]
        hybrid_cache.update(
            key_states=new_states,
            value_states=new_states,
            layer_idx=layer_idx,
            cache_kwargs={"cache_position": torch.tensor([3])},
        )

    # The full-attention layer writes the new token into its last free slot.
    self.assertEqual(
        first_channel(hybrid_cache.layers[0].keys),
        [1.0, 2.0, 3.0, 5.0],
        "Static layer did not update as expected.",
    )

    # The sliding layer drops its oldest entry and appends the new token.
    self.assertEqual(
        first_channel(hybrid_cache.layers[1].keys),
        [30.0, 50.0],
        "Sliding layer did not slide as expected.",
    )
|
||||
|
||||
def test_hybrid_chunked_cache(self):
    """
    Exercise HybridChunkedCache with one full-attention layer (layer 0) and one
    sliding-attention layer (layer 1, `sliding_window = 2`), `max_cache_len = 4`.

    Steps and expected values:
        1. Prefill of 3 tokens (longer than the window)
             layer 0 returns [1, 2, 3, 0]
             layer 1 returns [10, 20, 30] and stores [20, 30]
        2. Single-token decode at cache position 3
             layer 0 stores [1, 2, 3, 5]
             layer 1 stores and returns [30, 50]
        3. Two-token update on layer 1 at cache positions 4 and 5
             layer 1 stores [60, 70] and returns [50, 60, 70]
    """

    def first_channel(tensor):
        # Flatten the sequence dimension of a 4-D states tensor into a Python list.
        return tensor[0, 0, :, 0].tolist()

    mixed_config = copy.deepcopy(self.config)
    mixed_config.num_hidden_layers = 2
    mixed_config.layer_types = ["full_attention", "sliding_attention"]
    mixed_config.sliding_window = 2
    max_cache_len = 4
    chunked_cache = HybridChunkedCache(config=mixed_config, max_batch_size=1, max_cache_len=max_cache_len)

    # 1) PREFILL (3 tokens > sliding_window)
    static_prompt = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
    sliding_prompt = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]

    returned_static = chunked_cache.update(
        key_states=static_prompt,
        value_states=static_prompt,
        layer_idx=0,
        cache_kwargs={"cache_position": torch.arange(3)},
    )
    returned_sliding = chunked_cache.update(
        key_states=sliding_prompt,
        value_states=sliding_prompt,
        layer_idx=1,
        cache_kwargs={"cache_position": torch.arange(3)},
    )

    # Layer 0 returns the prompt followed by the still-empty fourth slot.
    self.assertEqual(first_channel(returned_static[0]), [1.0, 2.0, 3.0, 0.0])
    # Layer 1 returns the whole prompt while keeping only its last two tokens.
    self.assertEqual(first_channel(returned_sliding[0]), [10.0, 20.0, 30.0])
    self.assertEqual(first_channel(chunked_cache.layers[1].keys), [20.0, 30.0])

    # 2) ONE-TOKEN UPDATE (normal decode)
    static_token = torch.tensor(5.0)[None, None, None, None]
    sliding_token = torch.tensor(50.0)[None, None, None, None]

    chunked_cache.update(
        key_states=static_token,
        value_states=static_token,
        layer_idx=0,
        cache_kwargs={"cache_position": torch.tensor([3])},
    )
    returned_one = chunked_cache.update(
        key_states=sliding_token,
        value_states=sliding_token,
        layer_idx=1,
        cache_kwargs={"cache_position": torch.tensor([3])},
    )

    self.assertEqual(first_channel(chunked_cache.layers[0].keys), [1.0, 2.0, 3.0, 5.0])
    self.assertEqual(first_channel(chunked_cache.layers[1].keys), [30.0, 50.0])
    self.assertEqual(first_channel(returned_one[0]), [30.0, 50.0])

    # 3) TWO-TOKEN UPDATE after window is full (positions continue from the previous step)
    sliding_pair = torch.tensor([60.0, 70.0])[None, None, :, None]
    returned_two = chunked_cache.update(
        key_states=sliding_pair,
        value_states=sliding_pair,
        layer_idx=1,
        cache_kwargs={"cache_position": torch.tensor([4, 5])},
    )

    # Only the two newest tokens remain stored ...
    self.assertEqual(first_channel(chunked_cache.layers[1].keys), [60.0, 70.0])
    # ... while the returned keys also include the last token stored before this update.
    self.assertEqual(first_channel(returned_two[0]), [50.0, 60.0, 70.0])
|
||||
|
||||
def test_hybrid_chunked_cache_extra_cases(self):
    """
    Cover the multi-token updates that show up when the prefill is chunked:
        1) an update that still fits (cache_position[0] + update_len <= max_cache_len)
        2) an update that crosses the window boundary
           (cache_position[0] < max_cache_len and cache_position[0] + update_len > max_cache_len)

    Setup: a single sliding-attention layer with sliding_window = 3 and max_cache_len = 3.

    Step 0 (2-token prefill at positions 0-1, update_len < max_cache_len):
        stored [10, 20, 0], returned [10, 20, 0]

    Step 1 (2 more tokens at positions 2-3, so 2 + 2 = 4 > max_cache_len):
        stored [20, 30, 40], returned [10, 20, 30, 40]
    """

    def first_channel(tensor):
        # Flatten the sequence dimension of a 4-D states tensor into a Python list.
        return tensor[0, 0, :, 0].tolist()

    sliding_only_config = copy.deepcopy(self.config)
    sliding_only_config.num_hidden_layers = 1
    sliding_only_config.layer_types = ["sliding_attention"]
    sliding_only_config.sliding_window = 3
    cache = HybridChunkedCache(sliding_only_config, max_batch_size=1, max_cache_len=3)

    # Step 0 : multi-token prefill that leaves one slot unused
    chunk_a = torch.tensor([10.0, 20.0])[None, None, :, None]
    out_a = cache.update(
        key_states=chunk_a,
        value_states=chunk_a,
        layer_idx=0,
        cache_kwargs={"cache_position": torch.arange(2)},
    )

    # Both the stored keys and the returned keys hold the two tokens plus a zero in the unused slot.
    self.assertEqual(first_channel(cache.layers[0].keys), [10.0, 20.0, 0.0])
    self.assertEqual(first_channel(out_a[0]), [10.0, 20.0, 0.0])

    # Step 1 : multi-token update that starts inside the window and ends past it
    chunk_b = torch.tensor([30.0, 40.0])[None, None, :, None]
    out_b = cache.update(
        key_states=chunk_b,
        value_states=chunk_b,
        layer_idx=0,
        cache_kwargs={"cache_position": torch.tensor([2, 3])},
    )

    # Stored keys hold the three newest tokens; the returned keys hold all four tokens seen so far.
    self.assertEqual(first_channel(cache.layers[0].keys), [20.0, 30.0, 40.0])
    self.assertEqual(first_channel(out_b[0]), [10.0, 20.0, 30.0, 40.0])
|
||||
|
||||
Reference in New Issue
Block a user