New cache tests and refactored Hybrid Cache (#37972)

This commit is contained in:
Manuel de Prada Corral
2025-05-20 12:46:13 +02:00
committed by GitHub
parent 183fb3637c
commit d34e21e7dd
2 changed files with 471 additions and 147 deletions

View File

@@ -46,10 +46,14 @@ if is_torch_available():
Cache,
ClvpForCausalLM,
DynamicCache,
Gemma2Config,
GenerationConfig,
HybridCache,
LlamaConfig,
SlidingWindowCache,
StaticCache,
convert_and_export_with_cache,
pipeline,
)
@@ -188,6 +192,21 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
def _skip_on_failed_cache_prerequisites(test, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
# Installed dependencies
if cache_implementation == "quantized" and not is_optimum_quanto_available():
test.skipTest("Quanto is not available")
# Devices
if "offloaded" in cache_implementation:
has_accelerator = torch_device is not None and torch_device != "cpu"
if not has_accelerator:
test.skipTest("Offloaded caches require an accelerator")
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
if backend_device_count(torch_device) != 1:
test.skipTest("Offloaded static caches require exactly 1 accelerator")
class CacheIntegrationTest(unittest.TestCase):
"""Fast cache integration tests that share the same small model"""
@@ -200,24 +219,10 @@ class CacheIntegrationTest(unittest.TestCase):
)
cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows
def _skip_on_failed_cache_prerequisites(self, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
# Installed dependencies
if cache_implementation == "quantized" and not is_optimum_quanto_available():
self.skipTest("Quanto is not available")
# Devices
if "offloaded" in cache_implementation:
has_accelerator = torch_device is not None and torch_device != "cpu"
if not has_accelerator:
self.skipTest("Offloaded caches require an accelerator")
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
if backend_device_count(torch_device) != 1:
self.skipTest("Offloaded static caches require exactly 1 accelerator")
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_batched(self, cache_implementation):
"""Sanity check: caches' `.update` function expects batched inputs"""
self._skip_on_failed_cache_prerequisites(cache_implementation)
_skip_on_failed_cache_prerequisites(self, cache_implementation)
EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
@@ -246,7 +251,7 @@ class CacheIntegrationTest(unittest.TestCase):
Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
(an output sequence contains multiple beam indices).
"""
self._skip_on_failed_cache_prerequisites(cache_implementation)
_skip_on_failed_cache_prerequisites(self, cache_implementation)
if cache_implementation == "offloaded_hybrid_chunked":
# TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
# output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
@@ -280,7 +285,7 @@ class CacheIntegrationTest(unittest.TestCase):
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the cache"""
self._skip_on_failed_cache_prerequisites(cache_implementation)
_skip_on_failed_cache_prerequisites(self, cache_implementation)
EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]
@@ -552,6 +557,28 @@ class CacheHardIntegrationTest(unittest.TestCase):
_ = model(**inputs)
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
@require_torch_gpu
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_gptj_model(self, cache_implementation):
    """Greedy generation with a tiny GPT-J checkpoint must yield fixed token ids for every cache type.

    Regression test for https://github.com/huggingface/transformers/pull/34799
    """
    _skip_on_failed_cache_prerequisites(self, cache_implementation)

    generator = pipeline(
        "text-generation",
        model="hf-internal-testing/tiny-random-GPTJForCausalLM",
        torch_dtype=torch.bfloat16,
    )

    # Only these cache implementations get a `sliding_window` set on the config; the rest get `None`
    windowed_implementations = ["sliding_window", "hybrid", "hybrid_chunked"]
    if cache_implementation in windowed_implementations:
        generator.model.config.sliding_window = 256
    else:
        generator.model.config.sliding_window = None

    results = generator(
        "hello world",
        cache_implementation=cache_implementation,
        max_new_tokens=10,
        do_sample=False,
        disable_compile=True,
        return_tensors=True,
    )
    # Keep only the last 10 ids, i.e. as many as `max_new_tokens`
    trailing_token_ids = results[0]["generated_token_ids"][-10:]

    expected_token_ids = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441]
    self.assertListEqual(trailing_token_ids, expected_token_ids)
@require_torch
class CacheExportIntegrationTest(unittest.TestCase):
@@ -721,3 +748,276 @@ class CacheExportIntegrationTest(unittest.TestCase):
dynamic_shapes=dynamic_shapes,
strict=False,
)
class SyntheticCacheTest(unittest.TestCase):
    """Drives the cache classes directly with tiny hand-written tensors (no model involved)."""

    def setUp(self):
        """Create the minimal single-layer config that every test builds its cache from."""
        self.window_size = 4
        self.max_cache_len = 4
        self.config = Gemma2Config(
            hidden_size=1,
            head_dim=1,
            num_attention_heads=1,
            num_key_value_heads=1,
            num_hidden_layers=1,
            sliding_window=self.window_size,
            sliding_window_pattern=2,  # Default pattern for hybrid sliding
        )

    @staticmethod
    def _states(values):
        """Return `values` (a flat list of floats) as a tensor of shape (1, 1, len(values), 1).

        NOTE(review): presumably the axes are (batch, heads, sequence, head_dim) -- confirm against
        the cache classes.
        """
        return torch.tensor(values)[None, None, :, None]

    def _push(self, cache, values, positions, with_window=False):
        """Write `values` into layer 0 of `cache` as both keys and values at `positions`.

        When `with_window` is set, `sliding_window=self.window_size` is also passed in `cache_kwargs`.
        """
        kwargs = {"cache_position": positions}
        if with_window:
            kwargs["sliding_window"] = self.window_size
        cache.update(
            key_states=self._states(values),
            value_states=self._states(values),
            layer_idx=0,
            cache_kwargs=kwargs,
        )

    @staticmethod
    def _stored_keys(cache):
        """Return the layer-0 key cache of `cache` flattened to a plain list of floats."""
        return cache.key_cache[0][0, 0, :, 0].tolist()

    def test_static_cache_out_of_bounds(self):
        """A StaticCache update at position `max_cache_len` must raise IndexError."""
        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        first_invalid_position = torch.tensor([self.max_cache_len])  # Position >= max_cache_len

        with self.assertRaises(IndexError):
            self._push(static_cache, [1.0], first_invalid_position)

    def test_static_cache(self):
        """Check StaticCache contents after a manual prefill followed by single-position updates.

        Step 1 (near capacity):
            prefill         -> [1.0, 2.0, 0.0, 0.0]
            write 3.0 @ 2   -> [1.0, 2.0, 3.0, 0.0]
        Step 2 (at capacity, same cache):
            write 4.0 @ 3   -> [1.0, 2.0, 3.0, 4.0]
        """
        static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)

        # Step 1: the prefill deliberately passes no cache kwargs at all
        initial_states = self._states([1.0, 2.0, 0.0, 0.0])
        static_cache.update(key_states=initial_states, value_states=initial_states, layer_idx=0, cache_kwargs=None)
        self._push(static_cache, [3.0], torch.tensor([2]))
        self.assertEqual(self._stored_keys(static_cache), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed")

        # Step 2: fill the last free slot
        self._push(static_cache, [4.0], torch.tensor([3]))
        self.assertEqual(self._stored_keys(static_cache), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed")

    def test_sliding_window_cache(self):
        """Check SlidingWindowCache contents for in-window writes, sliding writes and long prompts.

        Step 1 (write inside the window, nothing slides):
            prefill         -> [1.0, 2.0, 0.0, 0.0]
            write 3.0 @ 2   -> [1.0, 2.0, 3.0, 0.0]
        Step 2 (write past the window, contents shift left), fresh cache:
            prefill         -> [1.0, 2.0, 3.0, 4.0]
            write 5.0 @ 4   -> [2.0, 3.0, 4.0, 5.0]
        Step 3 (prompt longer than the window), fresh cache:
            prefill with [1.0 .. 6.0] -> [3.0, 4.0, 5.0, 6.0] (only the last `window_size` entries remain)
        """
        full_window = torch.arange(4)

        # Step 1: write inside the window
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        self._push(sliding_cache, [1.0, 2.0, 0.0, 0.0], full_window, with_window=True)
        self._push(sliding_cache, [3.0], torch.tensor([2]), with_window=True)
        self.assertEqual(
            self._stored_keys(sliding_cache),
            [1.0, 2.0, 3.0, 0.0],
            "SlidingWindowCache Scenario 1 failed",
        )

        # Step 2: write one position beyond the window
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        self._push(sliding_cache, [1.0, 2.0, 3.0, 4.0], full_window, with_window=True)
        self._push(sliding_cache, [5.0], torch.tensor([4]), with_window=True)
        self.assertEqual(
            self._stored_keys(sliding_cache),
            [2.0, 3.0, 4.0, 5.0],
            "SlidingWindowCache Scenario 2 failed",
        )

        # Step 3: prefill with more tokens than the window holds
        sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        self._push(sliding_cache, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], torch.arange(6), with_window=True)
        self.assertEqual(
            self._stored_keys(sliding_cache),
            [3.0, 4.0, 5.0, 6.0],
            "SlidingWindowCache Scenario 3 failed",
        )

    def test_hybrid_cache_static_mode(self):
        """Check HybridCache contents when `sliding_window_pattern` is set to 1.

        Step 1:
            prefill         -> [1.0, 2.0, 0.0, 0.0]
            write 3.0 @ 2   -> [1.0, 2.0, 3.0, 0.0]
        Step 2 (same cache):
            write 4.0 @ 3   -> [1.0, 2.0, 3.0, 4.0]
        """
        static_config = copy.deepcopy(self.config)
        static_config.sliding_window_pattern = 1  # Layer 0 is static (1 % 1 == 0)
        hybrid_cache_static_mode = HybridCache(
            config=static_config, max_batch_size=1, max_cache_len=self.max_cache_len
        )

        # Step 1
        self._push(hybrid_cache_static_mode, [1.0, 2.0, 0.0, 0.0], torch.arange(4))
        self._push(hybrid_cache_static_mode, [3.0], torch.tensor([2]))
        self.assertEqual(
            self._stored_keys(hybrid_cache_static_mode),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Static Scenario 1 failed",
        )

        # Step 2
        self._push(hybrid_cache_static_mode, [4.0], torch.tensor([3]))
        self.assertEqual(
            self._stored_keys(hybrid_cache_static_mode),
            [1.0, 2.0, 3.0, 4.0],
            "HybridCache Static Scenario 2 failed",
        )

    def test_hybrid_cache_sliding_mode(self):
        """Check HybridCache contents with the default config from `setUp` (sliding layer 0).

        Step 1 (write inside the window, nothing slides):
            prefill         -> [1.0, 2.0, 0.0, 0.0]
            write 3.0 @ 2   -> [1.0, 2.0, 3.0, 0.0]
        Step 2 (first slide), fresh cache:
            prefill         -> [1.0, 2.0, 3.0, 4.0]
            write 5.0 @ 4   -> [2.0, 3.0, 4.0, 5.0]
        Step 3 (second slide, same cache as step 2):
            write 6.0 @ 5   -> [3.0, 4.0, 5.0, 6.0]
        Step 4 (prompt longer than the window), fresh cache:
            prefill with [1.0 .. 6.0] -> [3.0, 4.0, 5.0, 6.0] (only the last `window_size` entries remain)
        """
        full_window = torch.arange(4)

        # Step 1: write inside the window
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        self._push(hybrid_cache, [1.0, 2.0, 0.0, 0.0], full_window, with_window=True)
        self._push(hybrid_cache, [3.0], torch.tensor([2]), with_window=True)
        self.assertEqual(
            self._stored_keys(hybrid_cache),
            [1.0, 2.0, 3.0, 0.0],
            "HybridCache Sliding Scenario 1 failed",
        )

        # Step 2: first write beyond the window
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        self._push(hybrid_cache, [1.0, 2.0, 3.0, 4.0], full_window, with_window=True)
        self._push(hybrid_cache, [5.0], torch.tensor([4]), with_window=True)
        self.assertEqual(
            self._stored_keys(hybrid_cache),
            [2.0, 3.0, 4.0, 5.0],
            "HybridCache Sliding Scenario 2 failed",
        )

        # Step 3: keep writing into the cache from step 2
        self._push(hybrid_cache, [6.0], torch.tensor([5]), with_window=True)
        self.assertEqual(
            self._stored_keys(hybrid_cache),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 3 failed",
        )

        # Step 4: prefill with more tokens than the window holds
        hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
        self._push(hybrid_cache, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], torch.arange(6), with_window=True)
        self.assertEqual(
            self._stored_keys(hybrid_cache),
            [3.0, 4.0, 5.0, 6.0],
            "HybridCache Sliding Scenario 4 failed",
        )