[cache refactor] Move all the caching logic to a per-layer approach (#39106)

* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests

* fix quantized, add tests

* remove CacheProcessorList

* raushan review, arthur review

* joao review: minor things

* remove cache configs, make CacheLayer a mixin (joaos review)

* back to storage inside Cache()

* remove cachebase for decorator

* no more __getattr__

* fix tests

* joaos review except docs

* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`

More verbose exceptions in `fix_docstring` on docstring formatting issues.

* Revert "back to storage inside Cache()"

This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913.

* cyril review

* simplify cache export

* fix lfm2 cache

* HybridChunked to layer

* BC proxy object for cache.key_cache[i]=...

* reorder classes

* bfff come on LFM2

* better tests for hybrid and hybridChunked

* complete coverage for hybrid chunked caches (prefill chunking)

* reimplementing HybridChunked

* cyril review

* fix ci

* docs for cache refactor

* docs

* oopsie

* oopsie

* fix after merge

* cyril review

* arthur review

* opsie

* fix lfm2

* opsie2
This commit is contained in:
Manuel de Prada Corral
2025-07-22 16:10:25 +02:00
committed by GitHub
parent b16688e96a
commit c338fd43b0
64 changed files with 2779 additions and 2441 deletions

View File

@@ -1577,9 +1577,7 @@ class GenerationTesterMixin:
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1587,30 +1585,30 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1618,10 +1616,18 @@ class GenerationTesterMixin:
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
if is_legacy_cache:
self_attention_layer_keys = past_kv[i][0]
self_attention_layer_values = past_kv[i][1]
elif getattr(past_kv, "layers", None) is None:
# Cache is lot layered (i.e, Mamba derivatives)
self_attention_layer_keys = past_kv.key_cache[i]
self_attention_layer_values = past_kv.value_cache[i]
else:
self_attention_layer_keys = past_kv.layers[i].keys
self_attention_layer_values = past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@pytest.mark.generate
def test_generate_from_random_inputs_embeds(self):
@@ -1804,8 +1810,8 @@ class GenerationTesterMixin:
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
self.assertIsInstance(outputs.past_key_values, StaticCache)
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
self.assertEqual(len(outputs.past_key_values), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.layers[0].keys.shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@@ -2027,8 +2033,8 @@ class GenerationTesterMixin:
num_hidden_layers = text_config.num_hidden_layers
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
self.assertTrue(len(static_cache_generation.past_key_values) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.layers[0].keys.shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
@@ -2629,12 +2635,12 @@ class GenerationTesterMixin:
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
@@ -4040,13 +4046,13 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(isinstance(results.past_key_values, StaticCache))
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
keys_0 = results.past_key_values.layers[0].keys
values_0 = results.past_key_values.layers[0].values
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
keys_1 = results.past_key_values.layers[1].keys
values_1 = results.past_key_values.layers[1].values
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
@pytest.mark.generate
@require_torch_multi_accelerator
@@ -4118,13 +4124,13 @@ class GenerationIntegrationTests(unittest.TestCase):
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
keys_0 = results.past_key_values.layers[0].keys
values_0 = results.past_key_values.layers[0].values
self.assertTrue(keys_0.device == values_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
keys_1 = results.past_key_values.layers[1].keys
values_1 = results.past_key_values.layers[1].values
self.assertTrue(keys_1.device == values_1.device == torch.device(1))
@slow
def test_padding_input_contrastive_search_gpt2(self):

View File

@@ -168,14 +168,9 @@ class DeepseekV2ModelTest(CausalLMModelTest, unittest.TestCase):
expected_value_shape = expected_common_shape + (config.v_head_dim,)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_key_shape] * len(decoder_past_key_values.key_cache),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_value_shape] * len(decoder_past_key_values.value_cache),
)
for layer in decoder_past_key_values.layers:
self.assertEqual(layer.keys.shape, expected_key_shape)
self.assertEqual(layer.values.shape, expected_value_shape)
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
def test_generate_compilation_all_outputs(self):

View File

@@ -440,13 +440,11 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# difference: last dim
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
v_embed_dim = config.v_head_dim
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
self_attention_keys_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_values_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
# build the full cache shapes
num_hidden_layers = config.num_hidden_layers
all_cache_shapes = [
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
]
all_cache_shapes = [[self_attention_keys_shape, self_attention_values_shape] for _ in range(num_hidden_layers)]
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
@require_torch_large_accelerator

View File

@@ -399,12 +399,12 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
[layer.keys.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
[layer.values.shape for layer in decoder_past_key_values.layers],
[expected_shape] * len(decoder_past_key_values.layers),
)
def _check_scores(self, batch_size, scores, generated_length, config):

View File

@@ -38,7 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model
from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model
from transformers.models.falcon_h1.modeling_falcon_h1 import (
FalconHybridMambaAttentionDynamicCache,
)
@@ -272,6 +272,43 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
{"feature-extraction": FalconH1Model, "text-generation": FalconH1ForCausalLM} if is_torch_available() else {}
)
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
cache_length,
config.hidden_size // config.num_attention_heads,
)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
)
self.assertListEqual(
[value_cache.shape for value_cache in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
)
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
else:
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
[True] * len(decoder_past_key_values),
)
# check shape key, value
self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
[expected_shape] * len(decoder_past_key_values),
)
self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
[expected_shape] * len(decoder_past_key_values),
)
def setUp(self):
self.model_tester = FalconH1ModelTester(self)
self.config_tester = ConfigTester(self, config_class=FalconH1Config, hidden_size=64)

View File

@@ -235,8 +235,8 @@ class GPTNeoXModelTester:
"""Deep copy a DynamicCache to reuse the same one multiple times."""
new_cache = cache
for i in range(len(cache)):
new_cache.key_cache[i] = cache.key_cache[i].clone()
new_cache.value_cache[i] = cache.value_cache[i].clone()
new_cache.layers[i].keys = cache.layers[i].keys.clone()
new_cache.layers[i].values = cache.layers[i].values.clone()
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
# We need to run both on a copy of the cache, otherwise it is modified in-place

View File

@@ -271,7 +271,7 @@ class T5GemmaModelTester:
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertIsNotNone(decoder_past)
self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)
self.parent.assertEqual(len(decoder_past.cross_attention_cache), config.decoder.num_hidden_layers)
def check_prepare_lm_labels_via_shift_left(
self,
@@ -1060,9 +1060,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# 3. Check cache shapes
# 3.1. Encoder-Decoder checks
if config.is_encoder_decoder:
num_cache_decoder_layers = (
len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1070,30 +1068,30 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
self_attention_layer_keys = (
past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.layers[i].keys
)
self_attention_layer_value_cache = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
self_attention_layer_values = (
past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.layers[i].values
)
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
# Cross attention (ignore 3rd dim, see default shape preparation)
cross_attention_layer_key_cache = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
cross_attention_layer_keys = (
past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].keys
)
cross_attention_layer_value_cache = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
cross_attention_layer_values = (
past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.layers[i].values
)
cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
cross_attention_layer_keys = cross_attention_layer_keys[:, :, 0, :]
cross_attention_layer_values = cross_attention_layer_values[:, :, 0, :]
self.assertEqual(cross_attention_layer_keys.shape, all_cache_shapes[i][2])
self.assertEqual(cross_attention_layer_values.shape, all_cache_shapes[i][3])
# 3.2. Decoder-only checks
else:
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv)
self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
for i in range(num_decoder_layers):
@@ -1101,10 +1099,10 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
self_attention_layer_keys = past_kv[i][0] if is_legacy_cache else past_kv.layers[i].keys
self_attention_layer_values = past_kv[i][1] if is_legacy_cache else past_kv.layers[i].values
self.assertEqual(self_attention_layer_keys.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_values.shape, all_cache_shapes[i][1])
@unittest.skip("Mismatch issue doesn't exist in T5Gemma.")
def test_load_with_mismatched_shapes(self):

View File

@@ -36,7 +36,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal
from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal
if is_torch_available():
@@ -49,8 +49,12 @@ if is_torch_available():
DynamicCache,
Gemma2Config,
GenerationConfig,
HQQQuantizedCacheProcessor,
HybridCache,
HybridChunkedCache,
LlamaConfig,
QuantizedCache,
QuantoQuantizedCacheProcessor,
SlidingWindowCache,
StaticCache,
convert_and_export_with_cache,
@@ -252,6 +256,59 @@ class CacheIntegrationTest(unittest.TestCase):
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
self.assertListEqual(decoded, EXPECTED_GENERATION)
@parameterized.expand([("quanto"), ("HQQ")])
def test_quantized_cache_generation(self, backend):
"""Tests that QuantizedCache works as expected for both `quanto` and `hqq` backends."""
if backend == "quanto":
if not is_optimum_quanto_available():
self.skipTest("Quanto is not available")
axis_key, axis_value = 0, 0
# This output is taken from a run with the same parameters, and is known to be correct
expected_generation = ["The cat's whiskers are also a sign of anxiety."]
elif backend == "HQQ":
if not is_hqq_available():
self.skipTest("HQQ is not available")
axis_key, axis_value = 1, 1
# HQQ has slightly different numerics
expected_generation = ["The cat's whiskers are also a sign of anxiety."]
else:
return
inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device)
gen_out = self.model.generate(
**inputs,
do_sample=False,
max_new_tokens=10,
return_dict_in_generate=True,
cache_implementation="quantized",
cache_config={
"backend": backend,
"nbits": 4,
"q_group_size": 16,
"residual_length": 4,
"axis_key": axis_key,
"axis_value": axis_value,
},
disable_compile=True,
)
self.assertIsInstance(gen_out.past_key_values, QuantizedCache)
processor = gen_out.past_key_values.cache_processor
if backend == "quanto":
self.assertIsInstance(processor, QuantoQuantizedCacheProcessor)
elif backend == "hqq":
self.assertIsInstance(processor, HQQQuantizedCacheProcessor)
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
self.assertListEqual(decoded, expected_generation)
self.assertTrue(len(processor._quantized_keys) > 0)
# Check that something is actually quantized
has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys)
self.assertTrue(has_been_quantized)
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the cache"""
@@ -566,7 +623,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
@@ -587,11 +644,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
use_cache=True,
)
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers):
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
def test_dynamic_cache_exportability_multiple_run(self):
# When exporting with DynamicCache, you should export two graphs:
@@ -615,7 +670,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
@@ -640,9 +695,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
shapes = torch.export.ShapesCollection()
dyn = torch.export.Dim("seq", max=512)
for ix in range(len(past_key_values.key_cache)):
shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
for ix in range(len(past_key_values)):
shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)
ep_second = torch.export.export(
model,
@@ -683,11 +738,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
use_cache=True,
)
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers):
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
def test_static_cache_exportability(self):
@@ -726,8 +779,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
self.assertEqual(model.generation_config.max_length, max_cache_len)
self.assertTrue(model.generation_config.cache_config is not None)
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
self.assertEqual(model.generation_config.cache_config.get("batch_size"), batch_size)
self.assertEqual(model.generation_config.cache_config.get("max_cache_len"), max_cache_len)
exported_program = convert_and_export_with_cache(model)
@@ -830,7 +883,7 @@ class SyntheticCacheTest(unittest.TestCase):
head_dim=1,
hidden_size=1,
sliding_window=self.window_size,
sliding_window_pattern=2, # Default pattern for hybrid sliding
layer_types=["full_attention"] * 1, # Static cache by default
)
def test_static_cache_out_of_bounds(self):
@@ -867,7 +920,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2])},
)
self.assertEqual(
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
)
# Scenario 2: Fill to capacity
@@ -878,7 +931,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
)
def test_sliding_window_cache(self):
@@ -897,7 +950,9 @@ class SyntheticCacheTest(unittest.TestCase):
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
"""
# Scenario 1: Update within window, no slide yet
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
config = copy.deepcopy(self.config)
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
sliding_cache.update(
key_states=prefill,
@@ -912,13 +967,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"SlidingWindowCache Scenario 1 failed",
)
# Scenario 2: Update causing slide
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
sliding_cache.update(
key_states=prefill,
@@ -933,13 +988,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[2.0, 3.0, 4.0, 5.0],
"SlidingWindowCache Scenario 2 failed",
)
# Scenario 3: Long prompt handling
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
sliding_cache.update(
key_states=long_prefill,
@@ -948,13 +1003,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"SlidingWindowCache Scenario 3 failed",
)
def test_hybrid_cache_static_mode(self):
"""Test HybridCache in static mode with hardcoded assertions.
"""Test HybridCache with only 1 static layer.
Scenario 1: Static layer behavior
prefill: [1.0, 2.0, 0.0, 0.0]
@@ -964,7 +1019,7 @@ class SyntheticCacheTest(unittest.TestCase):
update pos 3: [1.0, 2.0, 3.0, 4.0]
"""
config = copy.deepcopy(self.config)
config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0)
config.layer_types = ["full_attention"] * config.num_hidden_layers
# Scenario 1
hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
@@ -982,7 +1037,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2])},
)
self.assertEqual(
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"HybridCache Static Scenario 1 failed",
)
@@ -995,7 +1050,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 4.0],
"HybridCache Static Scenario 2 failed",
)
@@ -1018,8 +1073,10 @@ class SyntheticCacheTest(unittest.TestCase):
input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
"""
config = copy.deepcopy(self.config)
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
# Scenario 1: Update within window, no slide yet
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
hybrid_cache.update(
key_states=prefill,
@@ -1034,13 +1091,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"HybridCache Sliding Scenario 1 failed",
)
# Scenario 2: Update causing first slide
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
hybrid_cache.update(
key_states=prefill,
@@ -1055,7 +1112,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[2.0, 3.0, 4.0, 5.0],
"HybridCache Sliding Scenario 2 failed",
)
@@ -1068,13 +1125,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"HybridCache Sliding Scenario 3 failed",
)
# Scenario 4: Long prompt handling
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
hybrid_cache.update(
key_states=long_prefill,
@@ -1083,7 +1140,278 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"HybridCache Sliding Scenario 4 failed",
)
def test_dynamic_cache(self):
"""Test DynamicCache with manually prefilled states and hardcoded assertions.
Scenario 1: prefill and update for one layer
prefill: [1.0, 2.0]
update pos 2: [1.0, 2.0, 3.0]
Scenario 2: prefill and update for two layers independently
"""
prefill = torch.tensor([1.0, 2.0])[None, None, :, None]
update3 = torch.tensor(3.0)[None, None, None, None]
update4 = torch.tensor(4.0)[None, None, None, None]
# Scenario 1: prefill and update for one layer
cache = DynamicCache()
cache.update(prefill, prefill, 0)
cache.update(update3, update3, 0)
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
cache.update(update4, update4, 0)
self.assertEqual(
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
)
# Scenario 2: prefill and update for two layers independently
prefill1 = torch.tensor([10.0, 20.0])[None, None, :, None]
update3_1 = torch.tensor(30.0)[None, None, None, None]
update4_1 = torch.tensor(40.0)[None, None, None, None]
cache = DynamicCache()
cache.update(prefill, prefill, 0)
cache.update(prefill1, prefill1, 1)
cache.update(update3, update3, 0)
cache.update(update3_1, update3_1, 1)
cache.update(update4, update4, 0)
cache.update(update4_1, update4_1, 1)
self.assertEqual(
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
)
self.assertEqual(
cache.layers[1].keys[0, 0, :, 0].tolist(),
[10.0, 20.0, 30.0, 40.0],
"DynamicCache Scenario 2 layer 1 failed",
)
def test_hybrid_cache(self):
"""
Test HybridCache with a mix of static and sliding layers,
with prefill size bigger than sliding window.
prefill:
static: [1.0, 2.0, 3.0]
sliding: [10.0, 20.0, 30.0]
(stores only [20.0, 30.0])
update pos 4:
static: [1.0, 2.0, 3.0, 5.0]
sliding: [30.0, 50.0]
"""
config = copy.deepcopy(self.config)
config.num_hidden_layers = 2
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 2
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
# Prefill both layers up to cache capacity
prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
# Sliding window is 2, so it should return full [10.0, 20.0, 30.0], but store only [20.0, 30.0]
prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]
# Update static layer (layer 0)
res_static = hybrid_cache.update(
key_states=prefill_static,
value_states=prefill_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.arange(3)},
)
# Update sliding layer (layer 1)
res_sliding = hybrid_cache.update(
key_states=prefill_sliding,
value_states=prefill_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.arange(3), "sliding_window": self.window_size},
)
# Verify initial states
self.assertEqual(
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"Initial static layer state is wrong",
)
self.assertEqual(
res_static[0][0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"Static layer did not return the correct value.",
)
self.assertEqual(
hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(),
[20.0, 30.0],
"Initial sliding layer state is wrong",
)
self.assertEqual(
res_sliding[0][0, 0, :, 0].tolist(),
[10.0, 20.0, 30.0],
"Sliding layer did not return the correct value.",
)
# Update at position 4
new_key_static = torch.tensor(5.0)[None, None, None, None]
new_key_sliding = torch.tensor(50.0)[None, None, None, None]
# Update static layer (layer 0)
hybrid_cache.update(
key_states=new_key_static,
value_states=new_key_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.tensor([3])},
)
# Update sliding layer (layer 1)
hybrid_cache.update(
key_states=new_key_sliding,
value_states=new_key_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.tensor([3])},
)
# The static layer does not slide, so it should have updated the element at position 3
self.assertEqual(
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 5.0],
"Static layer did not update as expected.",
)
# The sliding layer should have shifted, discarding the first element and adding the new one at the end
self.assertEqual(
hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(),
[30.0, 50.0],
"Sliding layer did not slide as expected.",
)
def test_hybrid_chunked_cache(self):
"""
Test HybridChunkedCache with both static and sliding layers and special cases:
1. a pre-fill longer than the sliding window
2. a single-token decoding step (normal generation)
3. a multi-token decoding step after the window is already full
Sliding-window size: 2
Static layer is full-attention.
─────────────────────────────────────────────
Prefill:
static : [1, 2, 3]
sliding : [10, 20, 30] (cache keeps [20, 30])
+1 token:
static : [1, 2, 3, 5]
sliding : [30, 50] (returned [30, 50])
+2 tokens:
sliding : [60, 70] (returned [50, 60, 70])
"""
config = copy.deepcopy(self.config)
config.num_hidden_layers = 2
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 2
max_cache_len = 4
chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len)
# 1) PREFILL (3 tokens > sliding_window)
prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]
res_static = chunked_cache.update(
key_states=prefill_static,
value_states=prefill_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.arange(3)},
)
res_sliding = chunked_cache.update(
key_states=prefill_sliding,
value_states=prefill_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.arange(3)},
)
# Static layer keeps everything
self.assertEqual(res_static[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0])
# Sliding layer returned full prompt but stored the tail
self.assertEqual(res_sliding[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0])
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [20.0, 30.0])
# 2) ONE-TOKEN UPDATE (normal decode)
new_static = torch.tensor(5.0)[None, None, None, None]
new_sliding = torch.tensor(50.0)[None, None, None, None]
chunked_cache.update(
key_states=new_static,
value_states=new_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.tensor([3])},
)
res_one = chunked_cache.update(
key_states=new_sliding,
value_states=new_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(chunked_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 5.0])
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [30.0, 50.0])
self.assertEqual(res_one[0][0, 0, :, 0].tolist(), [30.0, 50.0])
# 3) TWO-TOKEN UPDATE after window is full
new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None]
res_two = chunked_cache.update(
key_states=new_sliding_2,
value_states=new_sliding_2,
layer_idx=1,
cache_kwargs={"cache_position": torch.tensor([4, 5])}, # arbitrary positions; ignored in full mode
)
# Cache now keeps the latest two tokens
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [60.0, 70.0])
# Returned tensor contains previous last token + new ones
self.assertEqual(res_two[0][0, 0, :, 0].tolist(), [50.0, 60.0, 70.0])
def test_hybrid_chunked_cache_extra_cases(self):
"""
Covers the new cases that appear on prefill chunking:
1) Not full multi-token update (cache_position[0] + update_len <= max_cache_len)
2) Multi-token update crossing the window (cache_position[0] < max_cache_len and cache_position[0] + update_len > max_cache_len)
Single sliding layer, max_cache_len = 3.
Step 0 (prefill 2 tokens, update_len < max_cache_len
cache = [10, 20, 0] returned [10, 20, 0]
Step 1 (add 2 tokens, p = 2, update_len = 2, p + update_len = 4 > max_cache_len)
cache = [20, 30, 40] returned [10, 20, 30, 40]
"""
config = copy.deepcopy(self.config)
config.num_hidden_layers = 1
config.layer_types = ["sliding_attention"]
config.sliding_window = 3
cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3)
# Step 0 : multi-token prefill
first_chunk = torch.tensor([10.0, 20.0])[None, None, :, None] # L = 2
returned_0 = cache.update(
key_states=first_chunk,
value_states=first_chunk,
layer_idx=0,
cache_kwargs={"cache_position": torch.arange(2)}, # p = 0,1
)
# internal cache should have first two tokens and a zero pad
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [10.0, 20.0, 0.0])
self.assertEqual(returned_0[0][0, 0, :, 0].tolist(), [10.0, 20.0, 0.0])
# Step 1 : multi-token update crossing the window boundary
second_chunk = torch.tensor([30.0, 40.0])[None, None, :, None] # L = 2
returned_1 = cache.update(
key_states=second_chunk,
value_states=second_chunk,
layer_idx=0,
cache_kwargs={"cache_position": torch.tensor([2, 3])}, # p = 2
)
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [20.0, 30.0, 40.0])
self.assertEqual(returned_1[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0])