[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077) - Introduces CacheLayer and Cache base classes - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers - Implements method/attr dispatch across layers to reduce boilerplate - Adds CacheProcessor hooks for offloading, quantization, etc. - Updates and passes tests * fix quantized, add tests * remove CacheProcessorList * raushan review, arthur review * joao review: minor things * remove cache configs, make CacheLayer a mixin (joaos review) * back to storage inside Cache() * remove cachebase for decorator * no more __getattr__ * fix tests * joaos review except docs * fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant` More verbose exceptions in `fix_docstring` on docstring formatting issues. * Revert "back to storage inside Cache()" This reverts commit 27916bc2737806bf849ce2148cb1e66d59573913. * cyril review * simplify cache export * fix lfm2 cache * HybridChunked to layer * BC proxy object for cache.key_cache[i]=... * reorder classes * bfff come on LFM2 * better tests for hybrid and hybridChunked * complete coverage for hybrid chunked caches (prefill chunking) * reimplementing HybridChunked * cyril review * fix ci * docs for cache refactor * docs * oopsie * oopsie * fix after merge * cyril review * arthur review * opsie * fix lfm2 * opsie2
This commit is contained in:
committed by
GitHub
parent
b16688e96a
commit
c338fd43b0
@@ -36,7 +36,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal
|
||||
from transformers.utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -49,8 +49,12 @@ if is_torch_available():
|
||||
DynamicCache,
|
||||
Gemma2Config,
|
||||
GenerationConfig,
|
||||
HQQQuantizedCacheProcessor,
|
||||
HybridCache,
|
||||
HybridChunkedCache,
|
||||
LlamaConfig,
|
||||
QuantizedCache,
|
||||
QuantoQuantizedCacheProcessor,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
convert_and_export_with_cache,
|
||||
@@ -252,6 +256,59 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@parameterized.expand([("quanto"), ("HQQ")])
|
||||
def test_quantized_cache_generation(self, backend):
|
||||
"""Tests that QuantizedCache works as expected for both `quanto` and `hqq` backends."""
|
||||
if backend == "quanto":
|
||||
if not is_optimum_quanto_available():
|
||||
self.skipTest("Quanto is not available")
|
||||
axis_key, axis_value = 0, 0
|
||||
# This output is taken from a run with the same parameters, and is known to be correct
|
||||
expected_generation = ["The cat's whiskers are also a sign of anxiety."]
|
||||
elif backend == "HQQ":
|
||||
if not is_hqq_available():
|
||||
self.skipTest("HQQ is not available")
|
||||
axis_key, axis_value = 1, 1
|
||||
# HQQ has slightly different numerics
|
||||
expected_generation = ["The cat's whiskers are also a sign of anxiety."]
|
||||
else:
|
||||
return
|
||||
|
||||
inputs = self.tokenizer(["The cat"], return_tensors="pt").to(self.model.device)
|
||||
|
||||
gen_out = self.model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=10,
|
||||
return_dict_in_generate=True,
|
||||
cache_implementation="quantized",
|
||||
cache_config={
|
||||
"backend": backend,
|
||||
"nbits": 4,
|
||||
"q_group_size": 16,
|
||||
"residual_length": 4,
|
||||
"axis_key": axis_key,
|
||||
"axis_value": axis_value,
|
||||
},
|
||||
disable_compile=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(gen_out.past_key_values, QuantizedCache)
|
||||
processor = gen_out.past_key_values.cache_processor
|
||||
if backend == "quanto":
|
||||
self.assertIsInstance(processor, QuantoQuantizedCacheProcessor)
|
||||
elif backend == "hqq":
|
||||
self.assertIsInstance(processor, HQQQuantizedCacheProcessor)
|
||||
|
||||
decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, expected_generation)
|
||||
|
||||
self.assertTrue(len(processor._quantized_keys) > 0)
|
||||
|
||||
# Check that something is actually quantized
|
||||
has_been_quantized = any((q[0] if isinstance(q, tuple) else q).numel() > 0 for q in processor._quantized_keys)
|
||||
self.assertTrue(has_been_quantized)
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_extra_left_padding(self, cache_implementation):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the cache"""
|
||||
@@ -566,7 +623,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
past_key_values=DynamicCache(),
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
|
||||
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
|
||||
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
|
||||
self.assertEqual(
|
||||
3,
|
||||
@@ -587,11 +644,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
|
||||
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||
|
||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||
for l1, l2 in zip(res.past_key_values.layers, res_eager.past_key_values.layers):
|
||||
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
|
||||
|
||||
def test_dynamic_cache_exportability_multiple_run(self):
|
||||
# When exporting with DynamicCache, you should export two graphs:
|
||||
@@ -615,7 +670,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
past_key_values=DynamicCache(),
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
|
||||
self.assertTrue(len(res.past_key_values) == model.config.num_hidden_layers)
|
||||
self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
|
||||
self.assertEqual(
|
||||
3,
|
||||
@@ -640,9 +695,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
shapes = torch.export.ShapesCollection()
|
||||
dyn = torch.export.Dim("seq", max=512)
|
||||
|
||||
for ix in range(len(past_key_values.key_cache)):
|
||||
shapes[past_key_values.key_cache[ix]] = (None, None, dyn, None)
|
||||
shapes[past_key_values.value_cache[ix]] = (None, None, dyn, None)
|
||||
for ix in range(len(past_key_values)):
|
||||
shapes[past_key_values.layers[ix].keys] = (None, None, dyn, None)
|
||||
shapes[past_key_values.layers[ix].values] = (None, None, dyn, None)
|
||||
|
||||
ep_second = torch.export.export(
|
||||
model,
|
||||
@@ -683,11 +738,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||
|
||||
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||
for l1, l2 in zip(res_export_2.past_key_values.layers, res_eager_2.past_key_values.layers):
|
||||
self.assertTrue(torch.allclose(l1.keys, l2.keys, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(l1.values, l2.values, atol=1e-5))
|
||||
|
||||
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
|
||||
def test_static_cache_exportability(self):
|
||||
@@ -726,8 +779,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
|
||||
self.assertEqual(model.generation_config.max_length, max_cache_len)
|
||||
self.assertTrue(model.generation_config.cache_config is not None)
|
||||
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
|
||||
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
|
||||
self.assertEqual(model.generation_config.cache_config.get("batch_size"), batch_size)
|
||||
self.assertEqual(model.generation_config.cache_config.get("max_cache_len"), max_cache_len)
|
||||
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
|
||||
@@ -830,7 +883,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
head_dim=1,
|
||||
hidden_size=1,
|
||||
sliding_window=self.window_size,
|
||||
sliding_window_pattern=2, # Default pattern for hybrid sliding
|
||||
layer_types=["full_attention"] * 1, # Static cache by default
|
||||
)
|
||||
|
||||
def test_static_cache_out_of_bounds(self):
|
||||
@@ -867,7 +920,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2])},
|
||||
)
|
||||
self.assertEqual(
|
||||
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
|
||||
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
|
||||
)
|
||||
|
||||
# Scenario 2: Fill to capacity
|
||||
@@ -878,7 +931,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
self.assertEqual(
|
||||
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
|
||||
static_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
|
||||
)
|
||||
|
||||
def test_sliding_window_cache(self):
|
||||
@@ -897,7 +950,9 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
|
||||
"""
|
||||
# Scenario 1: Update within window, no slide yet
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
config = copy.deepcopy(self.config)
|
||||
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
|
||||
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -912,13 +967,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"SlidingWindowCache Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2: Update causing slide
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -933,13 +988,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[2.0, 3.0, 4.0, 5.0],
|
||||
"SlidingWindowCache Scenario 2 failed",
|
||||
)
|
||||
|
||||
# Scenario 3: Long prompt handling
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
sliding_cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=long_prefill,
|
||||
@@ -948,13 +1003,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
sliding_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"SlidingWindowCache Scenario 3 failed",
|
||||
)
|
||||
|
||||
def test_hybrid_cache_static_mode(self):
|
||||
"""Test HybridCache in static mode with hardcoded assertions.
|
||||
"""Test HybridCache with only 1 static layer.
|
||||
|
||||
Scenario 1: Static layer behavior
|
||||
prefill: [1.0, 2.0, 0.0, 0.0]
|
||||
@@ -964,7 +1019,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
update pos 3: [1.0, 2.0, 3.0, 4.0]
|
||||
"""
|
||||
config = copy.deepcopy(self.config)
|
||||
config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0)
|
||||
config.layer_types = ["full_attention"] * config.num_hidden_layers
|
||||
|
||||
# Scenario 1
|
||||
hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
@@ -982,7 +1037,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2])},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"HybridCache Static Scenario 1 failed",
|
||||
)
|
||||
@@ -995,7 +1050,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache_static_mode.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
"HybridCache Static Scenario 2 failed",
|
||||
)
|
||||
@@ -1018,8 +1073,10 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
|
||||
"""
|
||||
config = copy.deepcopy(self.config)
|
||||
config.layer_types = ["sliding_attention"] * config.num_hidden_layers
|
||||
# Scenario 1: Update within window, no slide yet
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -1034,13 +1091,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"HybridCache Sliding Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2: Update causing first slide
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=prefill,
|
||||
@@ -1055,7 +1112,7 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[2.0, 3.0, 4.0, 5.0],
|
||||
"HybridCache Sliding Scenario 2 failed",
|
||||
)
|
||||
@@ -1068,13 +1125,13 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"HybridCache Sliding Scenario 3 failed",
|
||||
)
|
||||
|
||||
# Scenario 4: Long prompt handling
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=long_prefill,
|
||||
@@ -1083,7 +1140,278 @@ class SyntheticCacheTest(unittest.TestCase):
|
||||
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"HybridCache Sliding Scenario 4 failed",
|
||||
)
|
||||
|
||||
def test_dynamic_cache(self):
|
||||
"""Test DynamicCache with manually prefilled states and hardcoded assertions.
|
||||
Scenario 1: prefill and update for one layer
|
||||
prefill: [1.0, 2.0]
|
||||
update pos 2: [1.0, 2.0, 3.0]
|
||||
Scenario 2: prefill and update for two layers independently
|
||||
"""
|
||||
prefill = torch.tensor([1.0, 2.0])[None, None, :, None]
|
||||
update3 = torch.tensor(3.0)[None, None, None, None]
|
||||
update4 = torch.tensor(4.0)[None, None, None, None]
|
||||
|
||||
# Scenario 1: prefill and update for one layer
|
||||
cache = DynamicCache()
|
||||
cache.update(prefill, prefill, 0)
|
||||
cache.update(update3, update3, 0)
|
||||
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0], "DynamicCache Scenario 1 failed")
|
||||
cache.update(update4, update4, 0)
|
||||
self.assertEqual(
|
||||
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 1 (to 4) failed"
|
||||
)
|
||||
|
||||
# Scenario 2: prefill and update for two layers independently
|
||||
prefill1 = torch.tensor([10.0, 20.0])[None, None, :, None]
|
||||
update3_1 = torch.tensor(30.0)[None, None, None, None]
|
||||
update4_1 = torch.tensor(40.0)[None, None, None, None]
|
||||
|
||||
cache = DynamicCache()
|
||||
cache.update(prefill, prefill, 0)
|
||||
cache.update(prefill1, prefill1, 1)
|
||||
|
||||
cache.update(update3, update3, 0)
|
||||
cache.update(update3_1, update3_1, 1)
|
||||
cache.update(update4, update4, 0)
|
||||
cache.update(update4_1, update4_1, 1)
|
||||
self.assertEqual(
|
||||
cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "DynamicCache Scenario 2 layer 0 failed"
|
||||
)
|
||||
self.assertEqual(
|
||||
cache.layers[1].keys[0, 0, :, 0].tolist(),
|
||||
[10.0, 20.0, 30.0, 40.0],
|
||||
"DynamicCache Scenario 2 layer 1 failed",
|
||||
)
|
||||
|
||||
def test_hybrid_cache(self):
|
||||
"""
|
||||
Test HybridCache with a mix of static and sliding layers,
|
||||
with prefill size bigger than sliding window.
|
||||
|
||||
prefill:
|
||||
static: [1.0, 2.0, 3.0]
|
||||
sliding: [10.0, 20.0, 30.0]
|
||||
(stores only [20.0, 30.0])
|
||||
|
||||
update pos 4:
|
||||
static: [1.0, 2.0, 3.0, 5.0]
|
||||
sliding: [30.0, 50.0]
|
||||
"""
|
||||
config = copy.deepcopy(self.config)
|
||||
config.num_hidden_layers = 2
|
||||
config.layer_types = ["full_attention", "sliding_attention"]
|
||||
config.sliding_window = 2
|
||||
hybrid_cache = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
|
||||
# Prefill both layers up to cache capacity
|
||||
prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
|
||||
# Sliding window is 2, so it should return full [10.0, 20.0, 30.0], but store only [20.0, 30.0]
|
||||
prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]
|
||||
|
||||
# Update static layer (layer 0)
|
||||
res_static = hybrid_cache.update(
|
||||
key_states=prefill_static,
|
||||
value_states=prefill_static,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(3)},
|
||||
)
|
||||
|
||||
# Update sliding layer (layer 1)
|
||||
res_sliding = hybrid_cache.update(
|
||||
key_states=prefill_sliding,
|
||||
value_states=prefill_sliding,
|
||||
layer_idx=1,
|
||||
cache_kwargs={"cache_position": torch.arange(3), "sliding_window": self.window_size},
|
||||
)
|
||||
|
||||
# Verify initial states
|
||||
self.assertEqual(
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"Initial static layer state is wrong",
|
||||
)
|
||||
self.assertEqual(
|
||||
res_static[0][0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"Static layer did not return the correct value.",
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(),
|
||||
[20.0, 30.0],
|
||||
"Initial sliding layer state is wrong",
|
||||
)
|
||||
self.assertEqual(
|
||||
res_sliding[0][0, 0, :, 0].tolist(),
|
||||
[10.0, 20.0, 30.0],
|
||||
"Sliding layer did not return the correct value.",
|
||||
)
|
||||
|
||||
# Update at position 4
|
||||
new_key_static = torch.tensor(5.0)[None, None, None, None]
|
||||
new_key_sliding = torch.tensor(50.0)[None, None, None, None]
|
||||
|
||||
# Update static layer (layer 0)
|
||||
hybrid_cache.update(
|
||||
key_states=new_key_static,
|
||||
value_states=new_key_static,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
|
||||
# Update sliding layer (layer 1)
|
||||
hybrid_cache.update(
|
||||
key_states=new_key_sliding,
|
||||
value_states=new_key_sliding,
|
||||
layer_idx=1,
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
|
||||
# The static layer does not slide, so it should have updated the element at position 3
|
||||
self.assertEqual(
|
||||
hybrid_cache.layers[0].keys[0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 5.0],
|
||||
"Static layer did not update as expected.",
|
||||
)
|
||||
|
||||
# The sliding layer should have shifted, discarding the first element and adding the new one at the end
|
||||
self.assertEqual(
|
||||
hybrid_cache.layers[1].keys[0, 0, :, 0].tolist(),
|
||||
[30.0, 50.0],
|
||||
"Sliding layer did not slide as expected.",
|
||||
)
|
||||
|
||||
def test_hybrid_chunked_cache(self):
|
||||
"""
|
||||
Test HybridChunkedCache with both static and sliding layers and special cases:
|
||||
1. a pre-fill longer than the sliding window
|
||||
2. a single-token decoding step (normal generation)
|
||||
3. a multi-token decoding step after the window is already full
|
||||
|
||||
Sliding-window size: 2
|
||||
Static layer is full-attention.
|
||||
─────────────────────────────────────────────
|
||||
Prefill:
|
||||
static : [1, 2, 3]
|
||||
sliding : [10, 20, 30] (cache keeps [20, 30])
|
||||
+1 token:
|
||||
static : [1, 2, 3, 5]
|
||||
sliding : [30, 50] (returned [30, 50])
|
||||
+2 tokens:
|
||||
sliding : [60, 70] (returned [50, 60, 70])
|
||||
"""
|
||||
|
||||
config = copy.deepcopy(self.config)
|
||||
config.num_hidden_layers = 2
|
||||
config.layer_types = ["full_attention", "sliding_attention"]
|
||||
config.sliding_window = 2
|
||||
max_cache_len = 4
|
||||
chunked_cache = HybridChunkedCache(config=config, max_batch_size=1, max_cache_len=max_cache_len)
|
||||
|
||||
# 1) PREFILL (3 tokens > sliding_window)
|
||||
prefill_static = torch.tensor([1.0, 2.0, 3.0])[None, None, :, None]
|
||||
prefill_sliding = torch.tensor([10.0, 20.0, 30.0])[None, None, :, None]
|
||||
|
||||
res_static = chunked_cache.update(
|
||||
key_states=prefill_static,
|
||||
value_states=prefill_static,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(3)},
|
||||
)
|
||||
res_sliding = chunked_cache.update(
|
||||
key_states=prefill_sliding,
|
||||
value_states=prefill_sliding,
|
||||
layer_idx=1,
|
||||
cache_kwargs={"cache_position": torch.arange(3)},
|
||||
)
|
||||
|
||||
# Static layer keeps everything
|
||||
self.assertEqual(res_static[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0])
|
||||
# Sliding layer returned full prompt but stored the tail
|
||||
self.assertEqual(res_sliding[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0])
|
||||
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [20.0, 30.0])
|
||||
|
||||
# 2) ONE-TOKEN UPDATE (normal decode)
|
||||
new_static = torch.tensor(5.0)[None, None, None, None]
|
||||
new_sliding = torch.tensor(50.0)[None, None, None, None]
|
||||
|
||||
chunked_cache.update(
|
||||
key_states=new_static,
|
||||
value_states=new_static,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
res_one = chunked_cache.update(
|
||||
key_states=new_sliding,
|
||||
value_states=new_sliding,
|
||||
layer_idx=1,
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
|
||||
self.assertEqual(chunked_cache.layers[0].keys[0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 5.0])
|
||||
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [30.0, 50.0])
|
||||
self.assertEqual(res_one[0][0, 0, :, 0].tolist(), [30.0, 50.0])
|
||||
|
||||
# 3) TWO-TOKEN UPDATE after window is full
|
||||
new_sliding_2 = torch.tensor([60.0, 70.0])[None, None, :, None]
|
||||
res_two = chunked_cache.update(
|
||||
key_states=new_sliding_2,
|
||||
value_states=new_sliding_2,
|
||||
layer_idx=1,
|
||||
cache_kwargs={"cache_position": torch.tensor([4, 5])}, # arbitrary positions; ignored in full mode
|
||||
)
|
||||
|
||||
# Cache now keeps the latest two tokens
|
||||
self.assertEqual(chunked_cache.layers[1].keys[0, 0, :, 0].tolist(), [60.0, 70.0])
|
||||
# Returned tensor contains previous last token + new ones
|
||||
self.assertEqual(res_two[0][0, 0, :, 0].tolist(), [50.0, 60.0, 70.0])
|
||||
|
||||
def test_hybrid_chunked_cache_extra_cases(self):
|
||||
"""
|
||||
Covers the new cases that appear on prefill chunking:
|
||||
1) Not full multi-token update (cache_position[0] + update_len <= max_cache_len)
|
||||
2) Multi-token update crossing the window (cache_position[0] < max_cache_len and cache_position[0] + update_len > max_cache_len)
|
||||
|
||||
Single sliding layer, max_cache_len = 3.
|
||||
|
||||
Step 0 (prefill 2 tokens, update_len < max_cache_len
|
||||
cache = [10, 20, 0] returned [10, 20, 0]
|
||||
|
||||
Step 1 (add 2 tokens, p = 2, update_len = 2, p + update_len = 4 > max_cache_len)
|
||||
cache = [20, 30, 40] returned [10, 20, 30, 40]
|
||||
"""
|
||||
|
||||
config = copy.deepcopy(self.config)
|
||||
config.num_hidden_layers = 1
|
||||
config.layer_types = ["sliding_attention"]
|
||||
config.sliding_window = 3
|
||||
cache = HybridChunkedCache(config, max_batch_size=1, max_cache_len=3)
|
||||
|
||||
# Step 0 : multi-token prefill
|
||||
first_chunk = torch.tensor([10.0, 20.0])[None, None, :, None] # L = 2
|
||||
returned_0 = cache.update(
|
||||
key_states=first_chunk,
|
||||
value_states=first_chunk,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(2)}, # p = 0,1
|
||||
)
|
||||
|
||||
# internal cache should have first two tokens and a zero pad
|
||||
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [10.0, 20.0, 0.0])
|
||||
self.assertEqual(returned_0[0][0, 0, :, 0].tolist(), [10.0, 20.0, 0.0])
|
||||
|
||||
# Step 1 : multi-token update crossing the window boundary
|
||||
second_chunk = torch.tensor([30.0, 40.0])[None, None, :, None] # L = 2
|
||||
returned_1 = cache.update(
|
||||
key_states=second_chunk,
|
||||
value_states=second_chunk,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([2, 3])}, # p = 2
|
||||
)
|
||||
|
||||
self.assertEqual(cache.layers[0].keys[0, 0, :, 0].tolist(), [20.0, 30.0, 40.0])
|
||||
self.assertEqual(returned_1[0][0, 0, :, 0].tolist(), [10.0, 20.0, 30.0, 40.0])
|
||||
|
||||
Reference in New Issue
Block a user