[tests] Test all cache implementations (#37873)

This commit is contained in:
Joao Gante
2025-04-30 15:37:00 +01:00
committed by GitHub
parent 2c1155519f
commit 1b222903c3
45 changed files with 338 additions and 438 deletions

View File

@@ -18,13 +18,12 @@ import unittest
from parameterized import parameterized
from transformers import set_seed
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
from transformers.testing_utils import (
CaptureStderr,
cleanup,
get_gpu_count,
is_torch_available,
require_gptq,
require_non_xpu,
require_torch,
require_torch_accelerator,
require_torch_gpu,
@@ -32,6 +31,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal
if is_torch_available():
@@ -40,15 +40,24 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Cache,
ClvpForCausalLM,
DynamicCache,
GenerationConfig,
LlamaConfig,
SinkCache,
StaticCache,
convert_and_export_with_cache,
)
from transformers.utils import is_torch_greater_or_equal
# Cache implementations exercised by the parameterized integration tests: everything listed in
# `ALL_CACHE_IMPLEMENTATIONS` except the two names filtered out below.
# TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`?
# TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them
TEST_CACHE_IMPLEMENTATIONS = [
    name for name in ALL_CACHE_IMPLEMENTATIONS if name not in ("mamba", "offloaded_hybrid")
]
@require_torch
@@ -176,9 +185,121 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
@require_torch_accelerator
class CacheIntegrationTest(unittest.TestCase):
"""Cache tests that require loading models"""
"""Fast cache integration tests that share the same small model"""
@classmethod
def setUpClass(cls):
    """Load the shared tokenizer and model once; every test in the class reuses them."""
    checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
    cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")

    shared_model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)
    # hack to enable the use of caches with sliding windows
    shared_model.config.sliding_window = 256
    cls.model = shared_model
def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
"""Function to skip tests on missing cache dependencies, given a cache implementation"""
if cache_implementation == "quantized" and not is_optimum_quanto_available():
self.skipTest("Quanto is not available")
if "offloaded" in cache_implementation:
has_accelerator = torch_device is not None and torch_device != "cpu"
if not has_accelerator:
self.skipTest("Offloaded caches require an accelerator")
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_batched(self, cache_implementation):
    """Sanity check: caches' `.update` function expects batched inputs"""
    self._skip_on_uninstalled_cache_dependencies(cache_implementation)

    prompts = ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"]
    expected_texts = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]

    # Two prompts of different lengths, padded into one batch and moved to the model's device
    model_inputs = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.model.device)
    outputs = self.model.generate(
        **model_inputs,
        cache_implementation=cache_implementation,
        disable_compile=True,
        do_sample=False,
        max_new_tokens=10,
        return_dict_in_generate=True,
    )

    # Sanity check: the returned `past_key_values` is a `Cache` instance
    self.assertIsInstance(outputs.past_key_values, Cache)

    # The decoded batch must match the reference continuations exactly
    generated_texts = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    self.assertListEqual(generated_texts, expected_texts)
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_beam_search(self, cache_implementation):
    """
    Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
    (an output sequence contains multiple beam indices).
    """
    self._skip_on_uninstalled_cache_dependencies(cache_implementation)
    if cache_implementation == "offloaded_hybrid_chunked":
        # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
        # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
        # different from the other caches.
        self.skipTest("`offloaded_hybrid_chunked` fails this test")

    expected_texts = [
        "Blue is the color of the sky, and the color of",
        "Blue is the color of the sky, and the second is",
    ]

    model_inputs = self.tokenizer(["Blue is"], return_tensors="pt").to(self.model.device)
    outputs = self.model.generate(
        **model_inputs,
        cache_implementation=cache_implementation,
        disable_compile=True,
        do_sample=False,
        max_new_tokens=10,
        num_beams=2,
        num_return_sequences=2,
        return_dict_in_generate=True,
    )

    # Sanity check: the returned `past_key_values` is a `Cache` instance
    self.assertIsInstance(outputs.past_key_values, Cache)

    # At least one of the sequences requires multiple beam indices -> `reorder_cache` had to shift things around
    self.assertTrue(any(len(set(sequence_beams)) > 1 for sequence_beams in outputs.beam_indices))

    # The decoded beams must match the reference continuations exactly
    generated_texts = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    self.assertListEqual(generated_texts, expected_texts)
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
    """Tests that adding extra left-padding does not affect the generation with the cache"""
    self._skip_on_uninstalled_cache_dependencies(cache_implementation)

    EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]

    inputs = self.tokenizer(["The cat"], padding=True, return_tensors="pt").to(self.model.device)
    generation_kwargs = {
        "do_sample": False,
        "max_new_tokens": 10,
        "cache_implementation": cache_implementation,
        "disable_compile": True,
    }

    # Baseline: no padding beyond what the single prompt needs
    gen_out = self.model.generate(**inputs, **generation_kwargs)
    decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)

    # Now with extra left-padding
    inputs_expanded = self.tokenizer(["The cat"], padding=True, return_tensors="pt", pad_to_multiple_of=32)
    inputs_expanded = inputs_expanded.to(self.model.device)
    # `assertLess` reports both lengths on failure, unlike `assertTrue(a < b)`
    self.assertLess(inputs.input_ids.shape[1], inputs_expanded.input_ids.shape[1])
    gen_out = self.model.generate(**inputs_expanded, **generation_kwargs)
    decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)
@require_torch_accelerator
class CacheHardIntegrationTest(unittest.TestCase):
"""Hard cache integration tests that require loading different models"""
def tearDown(self):
# Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests
@@ -187,18 +308,15 @@ class CacheIntegrationTest(unittest.TestCase):
@slow
def test_dynamic_cache_hard(self):
"""Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
# DynamicCache and the legacy cache format should be equivalent
set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = (
@@ -215,138 +333,11 @@ class CacheIntegrationTest(unittest.TestCase):
)
self.assertEqual(decoded[0], expected_text)
@slow
def test_dynamic_cache_batched(self):
    """Greedy batched generation with an explicitly passed `DynamicCache` matches the reference text."""
    checkpoint = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

    prompts = ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"]
    model_inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()
    )
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    reference_texts = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
    self.assertListEqual(generated_texts, reference_texts)
@slow
def test_dynamic_cache_beam_search(self):
    """Two-beam search returning both beams reproduces the reference continuations."""
    checkpoint = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

    model_inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
    beam_search_kwargs = {
        "do_sample": False,
        "max_new_tokens": 20,
        "num_beams": 2,
        "num_return_sequences": 2,
    }
    generated_ids = model.generate(**model_inputs, **beam_search_kwargs)
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    reference_texts = [
        "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
        "The best color is the one that suits you.\nThe best color is the one that suits you. The",
    ]
    self.assertListEqual(generated_texts, reference_texts)
@slow
def test_hybrid_cache_n_sequences(self):
    """Two-beam search on `google/gemma-2-9b` (eager attention) returns the two reference sequences."""
    checkpoint = "google/gemma-2-9b"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        attn_implementation="eager",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    model_inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        do_sample=False,
        max_new_tokens=20,
        num_beams=2,
        num_return_sequences=2,
    )
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    reference_texts = [
        "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
        "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
    ]
    self.assertListEqual(generated_texts, reference_texts)
@require_non_xpu
@require_gptq
@slow
def test_sink_cache_hard(self):
    """Long greedy generation through a `SinkCache` must end with the expected text."""
    checkpoint = "TheBloke/LLaMa-7B-GPTQ"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

    prompt = "Vaswani et al. (2017) introduced the Transformers"
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
    # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
    sink_cache = SinkCache(window_length=508, num_sink_tokens=4)
    generated_ids = model.generate(**model_inputs, do_sample=False, max_new_tokens=3000, past_key_values=sink_cache)
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    expected_ending = "to perform a variety of tasks. The Transformer is a neural network"
    self.assertTrue(generated_texts[0].endswith(expected_ending))
@slow
def test_sink_cache_iterative_prompts(self):
    """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
    )
    prompt = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
        "and must-see attractions."
    )

    # Prepare generation settings
    cache = SinkCache(window_length=256, num_sink_tokens=4)
    input_ids = torch.tensor([], device=model.device, dtype=torch.int)
    for _ in range(3):
        # Tokenize the prompt with the correct chat template
        chat = [{"role": "user", "content": prompt}]
        tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
            model.device
        )
        # Append the new user turn to everything generated so far; the same cache object is reused each round
        input_ids = torch.cat((input_ids, tokenized_chat), dim=1)

        # Perform the generation
        gen_out = model.generate(
            input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
        )
        input_ids = gen_out

    # We went well beyond the cache length (`assertGreater` reports both values on failure)
    self.assertGreater(input_ids.shape[1], cache.get_max_cache_shape() * 1.5)

    # And it still produces a coherent english
    decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    last_output = (
        "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
        "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
        "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
        "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
        "was visiting the historic district of Honolulu. Here,"
    )
    self.assertTrue(decoded[0].endswith(last_output))
@parameterized.expand(
[
("eager", "static"),
("sdpa", "static"),
]
)
@parameterized.expand([("eager"), ("sdpa")])
@require_torch_gpu
@slow
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
"""Tests that different cache implementations work well with eager and SDPA inference"""
EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
@@ -371,124 +362,19 @@ class CacheIntegrationTest(unittest.TestCase):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = cache_implementation
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
gen_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.forward = torch.compile(model.forward)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@slow
def test_dynamic_cache_extra_left_padding(self):
    """Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
    EXPECTED_GENERATION = [
        "The best color is the one that complements the skin tone of the",
        "We should not undermind the issues at hand.\nWe should not undermind the issues",
    ]

    tokenizer = AutoTokenizer.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
    )
    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf",
        torch_dtype=torch.bfloat16,
    ).to(torch_device)
    inputs = tokenizer(
        ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
    ).to(model.device)

    # Baseline: batch padded only up to the longest prompt
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
    decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)

    # Now with extra left-padding
    inputs_expanded = tokenizer(
        ["The best color is", "We should not undermind the issues at hand"],
        padding=True,
        return_tensors="pt",
        pad_to_multiple_of=32,
    ).to(model.device)
    # `assertLess` reports both lengths on failure, unlike `assertTrue(a < b)`
    self.assertLess(inputs.input_ids.shape[1], inputs_expanded.input_ids.shape[1])
    gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
    decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)
@slow
def test_static_cache_extra_left_padding(self):
    """Tests that adding extra left-padding does not affect the generation with the static cache"""
    EXPECTED_GENERATION = [
        "The best color is the one that complements the skin tone of the",
        "We should not undermind the issues at hand.\nWe should not undermind the issues",
    ]

    tokenizer = AutoTokenizer.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
    )
    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf",
        torch_dtype=torch.bfloat16,
    ).to(torch_device)
    inputs = tokenizer(
        ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
    ).to(model.device)

    # Both `generate` calls below pick up the static cache from the model's generation config
    model.generation_config.cache_implementation = "static"

    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
    decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)

    # Now with extra left-padding
    inputs_expanded = tokenizer(
        ["The best color is", "We should not undermind the issues at hand"],
        padding=True,
        return_tensors="pt",
        pad_to_multiple_of=32,
    ).to(model.device)
    # `assertLess` reports both lengths on failure, unlike `assertTrue(a < b)`
    self.assertLess(inputs.input_ids.shape[1], inputs_expanded.input_ids.shape[1])
    gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
    decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
    self.assertListEqual(decoded, EXPECTED_GENERATION)
@unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
def test_static_cache_beam_search(self):
# Placeholder: the decorator above skips this test unconditionally, so the empty body never runs.
pass
@require_torch_accelerator
@slow
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
    """Tests that OffloadedCache produces the same result as the default DynamicCache"""
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    device = model.device

    if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
        self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")

    input_text = "Fun fact:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    # Generation settings shared by both runs; only `cache_implementation` differs between them
    common = {
        "num_beams": 4,
        "num_beam_groups": 2,
        "num_return_sequences": 4,
        "diversity_penalty": 1.0,
        "max_new_tokens": 20,
        "early_stopping": True,
    }
    original = GenerationConfig(**common)
    offloaded = GenerationConfig(cache_implementation="offloaded", **common)
    original_outputs = model.generate(generation_config=original, **inputs)
    offloaded_outputs = model.generate(generation_config=offloaded, **inputs)

    # Compare sequence by sequence. A unittest assertion is used instead of a bare `assert`, which is
    # stripped when Python runs with `-O`.
    for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
        self.assertTrue(torch.all(original_output == offloaded_output).item())
@require_torch_accelerator
@slow
def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
@@ -526,12 +412,14 @@ class CacheIntegrationTest(unittest.TestCase):
torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=offloaded, **inputs)
offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
assert offloaded_peak_memory < original_peak_memory
self.assertTrue(offloaded_peak_memory < original_peak_memory)
@require_torch_gpu
@slow
def test_cache_copy(self):
"""Tests that we can manually set a cache, copy, and reuse it for generation"""
# TODO (joao): test for all cache implementations in `CacheIntegrationTest` after standardizing the
# lazy init of cache layers
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
@@ -542,7 +430,7 @@ class CacheIntegrationTest(unittest.TestCase):
INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
# This is the common prompt cached, we need to run forward without grad to be abel to copy
# This is the common prompt cached, we need to run forward without grad to be able to copy
with torch.no_grad():
prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
@@ -551,14 +439,19 @@ class CacheIntegrationTest(unittest.TestCase):
for prompt in prompts:
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40)
outputs = model.generate(
**new_inputs, past_key_values=past_key_values, max_new_tokens=40, disable_compile=True
)
response = tokenizer.batch_decode(outputs)[0]
responses.append(response)
EXPECTED_DECODED_TEXT = [
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful "
"way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a "
"first-time adventurer, there is always something",
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
"of France.\n\n\n\n\n\n\n<|endoftext|>",
]
self.assertEqual(responses, EXPECTED_DECODED_TEXT)
@require_torch_multi_gpu
@@ -609,7 +502,7 @@ class CacheIntegrationTest(unittest.TestCase):
# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
with CaptureStderr() as cap:
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
self.assertEqual(cap.err, "")
self.assertNotIn("cuda", cap.err.lower())
@require_torch_multi_gpu
@slow