[cache refactor] Move all the caching logic to a per-layer approach (#39106)

* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)

- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests

* fix quantized, add tests

* remove CacheProcessorList

* raushan review, arthur review

* joao review: minor things

* remove cache configs, make CacheLayer a mixin (joaos review)

* back to storage inside Cache()

* remove cachebase for decorator

* no more __getattr__

* fix tests

* joaos review except docs

* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`

More verbose exceptions in `fix_docstring` on docstring formatting issues.

* Revert "back to storage inside Cache()"

This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913.

* cyril review

* simplify cache export

* fix lfm2 cache

* HybridChunked to layer

* BC proxy object for cache.key_cache[i]=...

* reorder classes

* bfff come on LFM2

* better tests for hybrid and hybridChunked

* complete coverage for hybrid chunked caches (prefill chunking)

* reimplementing HybridChunked

* cyril review

* fix ci

* docs for cache refactor

* docs

* oopsie

* oopsie

* fix after merge

* cyril review

* arthur review

* opsie

* fix lfm2

* opsie2
This commit is contained in:
Manuel de Prada Corral
2025-07-22 16:10:25 +02:00
committed by GitHub
parent b16688e96a
commit c338fd43b0
64 changed files with 2779 additions and 2441 deletions

View File

@@ -36,7 +36,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal
from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal
if is_torch_available():
@@ -49,8 +49,12 @@ if is_torch_available():
DynamicCache,
Gemma2Config,
GenerationConfig,
HQQQuantizedCacheProcessor,
HybridCache,
HybridChunkedCache,
LlamaConfig,
QuantizedCache,
QuantoQuantizedCacheProcessor,
SlidingWindowCache,
StaticCache,
convert_and_export_with_cache,
@@ -252,6 +256,59 @@ class CacheIntegrationTest(unittest.TestCase):
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
self.assertListEqual(decoded, EXPECTED_GENERATION)
@parameterized.expand([("quanto"), ("HQQ")])
def test_quantized_cache_generation(self, backend):
"""Tests that QuantizedCache works as expected for both `quanto` and `hqq` backends."""
if backend == "quanto":
if not is_optimum_quanto_available():
self.skipTest("Quanto is not available")
axis_key, axis_value = 0, 0
# This output is taken from a run with the same parameters, and is known to be correct
expected_generation = ["The cat's whiskers are also a sign of anxiety."]
elif backend == "HQQ":
if not is_hqq_available():
self.skipTest("HQQ is not available")
axis_key, axis_value = 1, 1
# HQQ has slightly different numerics
expected_generation = ["The cat's whiskers are also a sign of anxiety."]
else:
return
inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device)
gen_out = self.model.generate(
**inputs,
do_sample=False,
max_new_tokens=10,
return_dict_in_generate=True,
cache_implementation="quantized",
cache_config={
"backend": backend,
"nbits": 4,
"q_group_size": 16,
"residual_length": 4,
"axis_key": axis_key,
"axis_value": axis_value,
},
disable_compile=True,
)
self.assertIsInstance(gen_out.past_key_values, QuantizedCache)
processor = gen_out.past_key_values.cache_processor
if backend == "quanto":
self.assertIsInstance(processor, QuantoQuantizedCacheProcessor)
elif backend == "hqq":
self.assertIsInstance(processor, HQQQuantizedCacheProcessor)
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
self.assertListEqual(decoded, expected_generation)
self.assertTrue(len(processor._quantized_keys) > 0)
# Check that something is actually quantized
has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys)
self.assertTrue(has_been_quantized)
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the cache"""
@@ -566,7 +623,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
@@ -587,11 +644,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
use_cache=True,
)
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers):
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
def test_dynamic_cache_exportability_multiple_run(self):
# When exporting with DynamicCache, you should export two graphs:
@@ -615,7 +670,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
self.assertEqual(
3,
@@ -640,9 +695,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
shapes = torch.export.ShapesCollection()
dyn = torch.export.Dim("seq", max=512)
for ix in range(len(past_key_values.key_cache)):
shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
for ix in range(len(past_key_values)):
shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)
ep_second = torch.export.export(
model,
@@ -683,11 +738,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
use_cache=True,
)
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers):
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
def test_static_cache_exportability(self):
@@ -726,8 +779,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
self.assertEqual(model.generation_config.max_length, max_cache_len)
self.assertTrue(model.generation_config.cache_config is not None)
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
self.assertEqual(model.generation_config.cache_config.get("batch_size"), batch_size)
self.assertEqual(model.generation_config.cache_config.get("max_cache_len"), max_cache_len)
exported_program = convert_and_export_with_cache(model)
@@ -830,7 +883,7 @@ class SyntheticCacheTest(unittest.TestCase):
head_dim=1,
hidden_size=1,
sliding_window=self.window_size,
sliding_window_pattern=2, # Default pattern for hybrid sliding
layer_types=["full_attention"] * 1, # Static cache by default
)
def test_static_cache_out_of_bounds(self):
@@ -867,7 +920,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2])},
)
self.assertEqual(
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
)
# Scenario 2: Fill to capacity
@@ -878,7 +931,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
)
def test_sliding_window_cache(self):
@@ -897,7 +950,9 @@ class SyntheticCacheTest(unittest.TestCase):
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
"""
# Scenario 1: Update within window, no slide yet
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
config = copy.deepcopy(self.config)
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
sliding_cache.update(
key_states=prefill,
@@ -912,13 +967,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"SlidingWindowCache Scenario 1 failed",
)
# Scenario 2: Update causing slide
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
sliding_cache.update(
key_states=prefill,
@@ -933,13 +988,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[2.0, 3.0, 4.0, 5.0],
"SlidingWindowCache Scenario 2 failed",
)
# Scenario 3: Long prompt handling
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
sliding_cache.update(
key_states=long_prefill,
@@ -948,13 +1003,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
)
self.assertEqual(
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"SlidingWindowCache Scenario 3 failed",
)
def test_hybrid_cache_static_mode(self):
"""Test HybridCache in static mode with hardcoded assertions.
"""Test HybridCache with only 1 static layer.
Scenario 1: Static layer behavior
prefill: [1.0, 2.0, 0.0, 0.0]
@@ -964,7 +1019,7 @@ class SyntheticCacheTest(unittest.TestCase):
update pos 3: [1.0, 2.0, 3.0, 4.0]
"""
config = copy.deepcopy(self.config)
config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0)
config.layer_types = ["full_attention"] * config.num_hidden_layers
# Scenario 1
hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
@@ -982,7 +1037,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2])},
)
self.assertEqual(
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"HybridCache Static Scenario 1 failed",
)
@@ -995,7 +1050,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 4.0],
"HybridCache Static Scenario 2 failed",
)
@@ -1018,8 +1073,10 @@ class SyntheticCacheTest(unittest.TestCase):
input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
"""
config = copy.deepcopy(self.config)
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
# Scenario 1: Update within window, no slide yet
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
hybrid_cache.update(
key_states=prefill,
@@ -1034,13 +1091,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"HybridCache Sliding Scenario 1 failed",
)
# Scenario 2: Update causing first slide
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
hybrid_cache.update(
key_states=prefill,
@@ -1055,7 +1112,7 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[2.0, 3.0, 4.0, 5.0],
"HybridCache Sliding Scenario 2 failed",
)
@@ -1068,13 +1125,13 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"HybridCache Sliding Scenario 3 failed",
)
# Scenario 4: Long prompt handling
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
hybrid_cache.update(
key_states=long_prefill,
@@ -1083,7 +1140,278 @@ class SyntheticCacheTest(unittest.TestCase):
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
)
self.assertEqual(
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[3.0, 4.0, 5.0, 6.0],
"HybridCache Sliding Scenario 4 failed",
)
def test_dynamic_cache(self):
"""Test DynamicCache with manually prefilled states and hardcoded assertions.
Scenario 1: prefill and update for one layer
prefill: [1.0, 2.0]
update pos 2: [1.0, 2.0, 3.0]
Scenario 2: prefill and update for two layers independently
"""
prefill = torch.tensor([1.0, 2.0])[None, None, :, None]
update3 = torch.tensor(3.0)[None, None, None, None]
update4 = torch.tensor(4.0)[None, None, None, None]
# Scenario 1: prefill and update for one layer
cache = DynamicCache()
cache.update(prefill, prefill, 0)
cache.update(update3, update3, 0)
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
cache.update(update4, update4, 0)
self.assertEqual(
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
)
# Scenario 2: prefill and update for two layers independently
prefill1 = torch.tensor([10.0, 20.0])[None, None, :, None]
update3_1 = torch.tensor(30.0)[None, None, None, None]
update4_1 = torch.tensor(40.0)[None, None, None, None]
cache = DynamicCache()
cache.update(prefill, prefill, 0)
cache.update(prefill1, prefill1, 1)
cache.update(update3, update3, 0)
cache.update(update3_1, update3_1, 1)
cache.update(update4, update4, 0)
cache.update(update4_1, update4_1, 1)
self.assertEqual(
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
)
self.assertEqual(
cache.layers[1].keys[0, 0, :, 0].tolist(),
[10.0, 20.0, 30.0, 40.0],
"DynamicCache Scenario 2 layer 1 failed",
)
def test_hybrid_cache(self):
"""
Test HybridCache with a mix of static and sliding layers,
with prefill size bigger than sliding window.
prefill:
static: [1.0, 2.0, 3.0]
sliding: [10.0, 20.0, 30.0]
(stores only [20.0, 30.0])
update pos 4:
static: [1.0, 2.0, 3.0, 5.0]
sliding: [30.0, 50.0]
"""
config = copy.deepcopy(self.config)
config.num_hidden_layers = 2
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 2
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
# Prefill both layers up to cache capacity
prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
# Sliding window is 2, so it should return full [10.0, 20.0, 30.0], but store only [20.0, 30.0]
prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]
# Update static layer (layer 0)
res_static = hybrid_cache.update(
key_states=prefill_static,
value_states=prefill_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.arange(3)},
)
# Update sliding layer (layer 1)
res_sliding = hybrid_cache.update(
key_states=prefill_sliding,
value_states=prefill_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.arange(3), "sliding_window": self.window_size},
)
# Verify initial states
self.assertEqual(
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"Initial static layer state is wrong",
)
self.assertEqual(
res_static[0][0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 0.0],
"Static layer did not return the correct value.",
)
self.assertEqual(
hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(),
[20.0, 30.0],
"Initial sliding layer state is wrong",
)
self.assertEqual(
res_sliding[0][0, 0, :, 0].tolist(),
[10.0, 20.0, 30.0],
"Sliding layer did not return the correct value.",
)
# Update at position 4
new_key_static = torch.tensor(5.0)[None, None, None, None]
new_key_sliding = torch.tensor(50.0)[None, None, None, None]
# Update static layer (layer 0)
hybrid_cache.update(
key_states=new_key_static,
value_states=new_key_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.tensor([3])},
)
# Update sliding layer (layer 1)
hybrid_cache.update(
key_states=new_key_sliding,
value_states=new_key_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.tensor([3])},
)
# The static layer does not slide, so it should have updated the element at position 3
self.assertEqual(
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
[1.0, 2.0, 3.0, 5.0],
"Static layer did not update as expected.",
)
# The sliding layer should have shifted, discarding the first element and adding the new one at the end
self.assertEqual(
hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(),
[30.0, 50.0],
"Sliding layer did not slide as expected.",
)
def test_hybrid_chunked_cache(self):
"""
Test HybridChunkedCache with both static and sliding layers and special cases:
1. a pre-fill longer than the sliding window
2. a single-token decoding step (normal generation)
3. a multi-token decoding step after the window is already full
Sliding-window size: 2
Static layer is full-attention.
─────────────────────────────────────────────
Prefill:
static : [1, 2, 3]
sliding : [10, 20, 30] (cache keeps [20, 30])
+1 token:
static : [1, 2, 3, 5]
sliding : [30, 50] (returned [30, 50])
+2 tokens:
sliding : [60, 70] (returned [50, 60, 70])
"""
config = copy.deepcopy(self.config)
config.num_hidden_layers = 2
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 2
max_cache_len = 4
chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len)
# 1) PREFILL (3 tokens > sliding_window)
prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]
res_static = chunked_cache.update(
key_states=prefill_static,
value_states=prefill_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.arange(3)},
)
res_sliding = chunked_cache.update(
key_states=prefill_sliding,
value_states=prefill_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.arange(3)},
)
# Static layer keeps everything
self.assertEqual(res_static[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0])
# Sliding layer returned full prompt but stored the tail
self.assertEqual(res_sliding[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0])
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [20.0, 30.0])
# 2) ONE-TOKEN UPDATE (normal decode)
new_static = torch.tensor(5.0)[None, None, None, None]
new_sliding = torch.tensor(50.0)[None, None, None, None]
chunked_cache.update(
key_states=new_static,
value_states=new_static,
layer_idx=0,
cache_kwargs={"cache_position": torch.tensor([3])},
)
res_one = chunked_cache.update(
key_states=new_sliding,
value_states=new_sliding,
layer_idx=1,
cache_kwargs={"cache_position": torch.tensor([3])},
)
self.assertEqual(chunked_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 5.0])
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [30.0, 50.0])
self.assertEqual(res_one[0][0, 0, :, 0].tolist(), [30.0, 50.0])
# 3) TWO-TOKEN UPDATE after window is full
new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None]
res_two = chunked_cache.update(
key_states=new_sliding_2,
value_states=new_sliding_2,
layer_idx=1,
cache_kwargs={"cache_position": torch.tensor([4, 5])}, # arbitrary positions; ignored in full mode
)
# Cache now keeps the latest two tokens
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [60.0, 70.0])
# Returned tensor contains previous last token + new ones
self.assertEqual(res_two[0][0, 0, :, 0].tolist(), [50.0, 60.0, 70.0])
def test_hybrid_chunked_cache_extra_cases(self):
"""
Covers the new cases that appear on prefill chunking:
1) Not full multi-token update (cache_position[0] + update_len <= max_cache_len)
2) Multi-token update crossing the window (cache_position[0] < max_cache_len and cache_position[0] + update_len > max_cache_len)
Single sliding layer, max_cache_len = 3.
Step 0 (prefill 2 tokens, update_len < max_cache_len
cache = [10, 20, 0] returned [10, 20, 0]
Step 1 (add 2 tokens, p = 2, update_len = 2, p + update_len = 4 > max_cache_len)
cache = [20, 30, 40] returned [10, 20, 30, 40]
"""
config = copy.deepcopy(self.config)
config.num_hidden_layers = 1
config.layer_types = ["sliding_attention"]
config.sliding_window = 3
cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3)
# Step 0 : multi-token prefill
first_chunk = torch.tensor([10.0, 20.0])[None, None, :, None] # L = 2
returned_0 = cache.update(
key_states=first_chunk,
value_states=first_chunk,
layer_idx=0,
cache_kwargs={"cache_position": torch.arange(2)}, # p = 0,1
)
# internal cache should have first two tokens and a zero pad
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [10.0, 20.0, 0.0])
self.assertEqual(returned_0[0][0, 0, :, 0].tolist(), [10.0, 20.0, 0.0])
# Step 1 : multi-token update crossing the window boundary
second_chunk = torch.tensor([30.0, 40.0])[None, None, :, None] # L = 2
returned_1 = cache.update(
key_states=second_chunk,
value_states=second_chunk,
layer_idx=0,
cache_kwargs={"cache_position": torch.tensor([2, 3])}, # p = 2
)
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [20.0, 30.0, 40.0])
self.assertEqual(returned_1[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0])