New cache tests and refactored Hybrid Cache (#37972)
This commit is contained in:
committed by
GitHub
parent
183fb3637c
commit
d34e21e7dd
@@ -46,10 +46,14 @@ if is_torch_available():
|
||||
Cache,
|
||||
ClvpForCausalLM,
|
||||
DynamicCache,
|
||||
Gemma2Config,
|
||||
GenerationConfig,
|
||||
HybridCache,
|
||||
LlamaConfig,
|
||||
SlidingWindowCache,
|
||||
StaticCache,
|
||||
convert_and_export_with_cache,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,6 +192,21 @@ class CacheTest(unittest.TestCase):
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
|
||||
def _skip_on_failed_cache_prerequisites(test, cache_implementation):
|
||||
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
|
||||
# Installed dependencies
|
||||
if cache_implementation == "quantized" and not is_optimum_quanto_available():
|
||||
test.skipTest("Quanto is not available")
|
||||
# Devices
|
||||
if "offloaded" in cache_implementation:
|
||||
has_accelerator = torch_device is not None and torch_device != "cpu"
|
||||
if not has_accelerator:
|
||||
test.skipTest("Offloaded caches require an accelerator")
|
||||
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
|
||||
if backend_device_count(torch_device) != 1:
|
||||
test.skipTest("Offloaded static caches require exactly 1 accelerator")
|
||||
|
||||
|
||||
class CacheIntegrationTest(unittest.TestCase):
|
||||
"""Fast cache integration tests that share the same small model"""
|
||||
|
||||
@@ -200,24 +219,10 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows
|
||||
|
||||
def _skip_on_failed_cache_prerequisites(self, cache_implementation):
|
||||
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
|
||||
# Installed dependencies
|
||||
if cache_implementation == "quantized" and not is_optimum_quanto_available():
|
||||
self.skipTest("Quanto is not available")
|
||||
# Devices
|
||||
if "offloaded" in cache_implementation:
|
||||
has_accelerator = torch_device is not None and torch_device != "cpu"
|
||||
if not has_accelerator:
|
||||
self.skipTest("Offloaded caches require an accelerator")
|
||||
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
|
||||
if backend_device_count(torch_device) != 1:
|
||||
self.skipTest("Offloaded static caches require exactly 1 accelerator")
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_batched(self, cache_implementation):
|
||||
"""Sanity check: caches' `.update` function expects batched inputs"""
|
||||
self._skip_on_failed_cache_prerequisites(cache_implementation)
|
||||
_skip_on_failed_cache_prerequisites(self, cache_implementation)
|
||||
|
||||
EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
|
||||
|
||||
@@ -246,7 +251,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
|
||||
(an output sequence contains multiple beam indices).
|
||||
"""
|
||||
self._skip_on_failed_cache_prerequisites(cache_implementation)
|
||||
_skip_on_failed_cache_prerequisites(self, cache_implementation)
|
||||
if cache_implementation == "offloaded_hybrid_chunked":
|
||||
# TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
|
||||
# output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
|
||||
@@ -280,7 +285,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_extra_left_padding(self, cache_implementation):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the cache"""
|
||||
self._skip_on_failed_cache_prerequisites(cache_implementation)
|
||||
_skip_on_failed_cache_prerequisites(self, cache_implementation)
|
||||
|
||||
EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]
|
||||
|
||||
@@ -552,6 +557,28 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
||||
_ = model(**inputs)
|
||||
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_gptj_model(self, cache_implementation):
|
||||
"""Tests caches with GPT-J model. Regression test for https://github.com/huggingface/transformers/pull/34799"""
|
||||
_skip_on_failed_cache_prerequisites(self, cache_implementation)
|
||||
|
||||
model_id = "hf-internal-testing/tiny-random-GPTJForCausalLM"
|
||||
pipe = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16)
|
||||
pipe.model.config.sliding_window = (
|
||||
256 if cache_implementation in ["sliding_window", "hybrid", "hybrid_chunked"] else None
|
||||
)
|
||||
out = pipe(
|
||||
"hello world",
|
||||
cache_implementation=cache_implementation,
|
||||
max_new_tokens=10,
|
||||
do_sample=False,
|
||||
disable_compile=True,
|
||||
return_tensors=True,
|
||||
)[0]["generated_token_ids"][-10:]
|
||||
EXPECTED_OUTPUT = [879, 175, 39, 141, 1000, 975, 951, 991, 683, 441]
|
||||
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch
|
||||
class CacheExportIntegrationTest(unittest.TestCase):
|
||||
@@ -721,3 +748,276 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
|
||||
class SyntheticCacheTest(unittest.TestCase):
|
||||
"""Tests cache behavior with simple dummy data."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up common configuration and cache instances for all tests."""
|
||||
self.window_size = 4
|
||||
self.max_cache_len = 4
|
||||
self.config = Gemma2Config(
|
||||
num_hidden_layers=1,
|
||||
num_key_value_heads=1,
|
||||
num_attention_heads=1,
|
||||
head_dim=1,
|
||||
hidden_size=1,
|
||||
sliding_window=self.window_size,
|
||||
sliding_window_pattern=2, # Default pattern for hybrid sliding
|
||||
)
|
||||
|
||||
def test_static_cache_out_of_bounds(self):
|
||||
"""Test StaticCache raises IndexError for out-of-bounds positions."""
|
||||
static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
pos_out_of_bounds = torch.tensor([self.max_cache_len]) # Position >= max_cache_len
|
||||
|
||||
with self.assertRaises(IndexError):
|
||||
static_cache.update(
|
||||
key_states=torch.tensor([[[[1.0]]]]),
|
||||
value_states=torch.tensor([[[[1.0]]]]),
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": pos_out_of_bounds},
|
||||
)
|
||||
|
||||
def test_static_cache(self):
|
||||
"""Test StaticCache with manually prefilled states and hardcoded assertions.
|
||||
|
||||
Scenario 1: Fill up to near capacity
|
||||
prefill: [1.0, 2.0, 0.0, 0.0]
|
||||
update pos 2: [1.0, 2.0, 3.0, 0.0]
|
||||
|
||||
Scenario 2: Fill to capacity
|
||||
update pos 3: [1.0, 2.0, 3.0, 4.0]
|
||||
"""
|
||||
# Scenario 1: Fill up to near capacity
|
||||
static_cache = StaticCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
static_cache.update(key_states=prefill, value_states=prefill, layer_idx=0, cache_kwargs=None)
|
||||
static_cache.update(
|
||||
key_states=torch.tensor(3.0)[None, None, None, None],
|
||||
value_states=torch.tensor(3.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([2])},
|
||||
)
|
||||
self.assertEqual(
|
||||
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 0.0], "StaticCache Scenario 1 failed"
|
||||
)
|
||||
|
||||
# Scenario 2: Fill to capacity
|
||||
static_cache.update(
|
||||
key_states=torch.tensor(4.0)[None, None, None, None],
|
||||
value_states=torch.tensor(4.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
self.assertEqual(
|
||||
static_cache.key_cache[0][0, 0, :, 0].tolist(), [1.0, 2.0, 3.0, 4.0], "StaticCache Scenario 2 failed"
|
||||
)
|
||||
|
||||
def test_sliding_window_cache(self):
|
||||
"""Test SlidingWindowCache with manually prefilled states and hardcoded assertions.
|
||||
|
||||
Scenario 1: Update within window, no slide yet
|
||||
prefill: [1.0, 2.0, 0.0, 0.0]
|
||||
update pos 2: [1.0, 2.0, 3.0, 0.0]
|
||||
|
||||
Scenario 2: Update causing slide
|
||||
prefill: [1.0, 2.0, 3.0, 4.0]
|
||||
update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)
|
||||
|
||||
Scenario 3: Long prompt handling (prompt_len > window_size)
|
||||
input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
|
||||
"""
|
||||
# Scenario 1: Update within window, no slide yet
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=prefill,
|
||||
value_states=prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
|
||||
)
|
||||
sliding_cache.update(
|
||||
key_states=torch.tensor(3.0)[None, None, None, None],
|
||||
value_states=torch.tensor(3.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"SlidingWindowCache Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2: Update causing slide
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=prefill,
|
||||
value_states=prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
|
||||
)
|
||||
sliding_cache.update(
|
||||
key_states=torch.tensor(5.0)[None, None, None, None],
|
||||
value_states=torch.tensor(5.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[2.0, 3.0, 4.0, 5.0],
|
||||
"SlidingWindowCache Scenario 2 failed",
|
||||
)
|
||||
|
||||
# Scenario 3: Long prompt handling
|
||||
sliding_cache = SlidingWindowCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
|
||||
sliding_cache.update(
|
||||
key_states=long_prefill,
|
||||
value_states=long_prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
sliding_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"SlidingWindowCache Scenario 3 failed",
|
||||
)
|
||||
|
||||
def test_hybrid_cache_static_mode(self):
|
||||
"""Test HybridCache in static mode with hardcoded assertions.
|
||||
|
||||
Scenario 1: Static layer behavior
|
||||
prefill: [1.0, 2.0, 0.0, 0.0]
|
||||
update pos 2: [1.0, 2.0, 3.0, 0.0]
|
||||
|
||||
Scenario 2: Fill to capacity
|
||||
update pos 3: [1.0, 2.0, 3.0, 4.0]
|
||||
"""
|
||||
config = copy.deepcopy(self.config)
|
||||
config.sliding_window_pattern = 1 # Layer 0 is static (1 % 1 == 0)
|
||||
|
||||
# Scenario 1
|
||||
hybrid_cache_static_mode = HybridCache(config=config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
hybrid_cache_static_mode.update(
|
||||
key_states=prefill,
|
||||
value_states=prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(4)},
|
||||
)
|
||||
hybrid_cache_static_mode.update(
|
||||
key_states=torch.tensor(3.0)[None, None, None, None],
|
||||
value_states=torch.tensor(3.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([2])},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"HybridCache Static Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2
|
||||
hybrid_cache_static_mode.update(
|
||||
key_states=torch.tensor(4.0)[None, None, None, None],
|
||||
value_states=torch.tensor(4.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([3])},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache_static_mode.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
"HybridCache Static Scenario 2 failed",
|
||||
)
|
||||
|
||||
def test_hybrid_cache_sliding_mode(self):
|
||||
"""Test HybridCache in sliding mode with hardcoded assertions.
|
||||
|
||||
Scenario 1: Update within window, no slide yet
|
||||
prefill: [1.0, 2.0, 0.0, 0.0]
|
||||
update pos 2: [1.0, 2.0, 3.0, 0.0]
|
||||
|
||||
Scenario 2: Update causing first slide
|
||||
prefill: [1.0, 2.0, 3.0, 4.0]
|
||||
update pos 4: [2.0, 3.0, 4.0, 5.0] (shift happens as pos > window_size-1)
|
||||
|
||||
Scenario 3: Update causing subsequent slide
|
||||
update pos 5: [3.0, 4.0, 5.0, 6.0] (shift continues)
|
||||
|
||||
Scenario 4: Long prompt handling (prompt_len > window_size)
|
||||
input: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
result: [3.0, 4.0, 5.0, 6.0] (keeps last window_size tokens)
|
||||
"""
|
||||
# Scenario 1: Update within window, no slide yet
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 0.0, 0.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=prefill,
|
||||
value_states=prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
|
||||
)
|
||||
hybrid_cache.update(
|
||||
key_states=torch.tensor(3.0)[None, None, None, None],
|
||||
value_states=torch.tensor(3.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([2]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[1.0, 2.0, 3.0, 0.0],
|
||||
"HybridCache Sliding Scenario 1 failed",
|
||||
)
|
||||
|
||||
# Scenario 2: Update causing first slide
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
prefill = torch.tensor([1.0, 2.0, 3.0, 4.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=prefill,
|
||||
value_states=prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(4), "sliding_window": self.window_size},
|
||||
)
|
||||
hybrid_cache.update(
|
||||
key_states=torch.tensor(5.0)[None, None, None, None],
|
||||
value_states=torch.tensor(5.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([4]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[2.0, 3.0, 4.0, 5.0],
|
||||
"HybridCache Sliding Scenario 2 failed",
|
||||
)
|
||||
|
||||
# Scenario 3: Update causing subsequent slide
|
||||
hybrid_cache.update(
|
||||
key_states=torch.tensor(6.0)[None, None, None, None],
|
||||
value_states=torch.tensor(6.0)[None, None, None, None],
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.tensor([5]), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"HybridCache Sliding Scenario 3 failed",
|
||||
)
|
||||
|
||||
# Scenario 4: Long prompt handling
|
||||
hybrid_cache = HybridCache(config=self.config, max_batch_size=1, max_cache_len=self.max_cache_len)
|
||||
long_prefill = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])[None, None, :, None]
|
||||
hybrid_cache.update(
|
||||
key_states=long_prefill,
|
||||
value_states=long_prefill,
|
||||
layer_idx=0,
|
||||
cache_kwargs={"cache_position": torch.arange(6), "sliding_window": self.window_size},
|
||||
)
|
||||
self.assertEqual(
|
||||
hybrid_cache.key_cache[0][0, 0, :, 0].tolist(),
|
||||
[3.0, 4.0, 5.0, 6.0],
|
||||
"HybridCache Sliding Scenario 4 failed",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user