[tests] Smaller model in slow cache tests (#37922)

This commit is contained in:
Joao Gante
2025-05-06 11:15:25 +01:00
committed by GitHub
parent ff5ef95db7
commit 9981214d32

View File

@@ -24,6 +24,7 @@ from transformers.testing_utils import (
cleanup, cleanup,
get_gpu_count, get_gpu_count,
is_torch_available, is_torch_available,
require_read_token,
require_torch, require_torch,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu, require_torch_gpu,
@@ -301,37 +302,52 @@ class CacheIntegrationTest(unittest.TestCase):
class CacheHardIntegrationTest(unittest.TestCase): class CacheHardIntegrationTest(unittest.TestCase):
"""Hard cache integration tests that require loading different models""" """Hard cache integration tests that require loading different models"""
def tearDown(self): def setUp(self):
# Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests # Clears memory before each test. Some tests use large models, which might result in suboptimal torch
# in a row # re-allocation if we run multiple tests in a row without clearing memory.
cleanup(torch_device, gc_collect=True)
@classmethod
def tearDownClass(cls):
# Clears memory after the last test. See `setUp` for more details.
cleanup(torch_device, gc_collect=True) cleanup(torch_device, gc_collect=True)
@slow @slow
def test_dynamic_cache_hard(self): def test_dynamic_cache_hard(self):
"""Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail""" """Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device) inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
set_seed(0) set_seed(0)
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256) gen_out = model.generate(
**inputs, do_sample=True, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = (
"Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
"to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
"Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
"what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
"are also very curious. They like to explore, and they like to play. They are also very fast. They can "
"run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
"can solve problems. They are also very playful. They like to play with toys, and they like to play with "
"other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
"also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
"clean their litter box.\nCats are also very independent. They don't"
) )
self.assertEqual(decoded[0], expected_text) decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
# sum of the scores for the generated tokens
input_length = inputs.input_ids.shape[1]
score_sum = sum(
[score[0][gen_out.sequences[0][input_length + idx]] for idx, score in enumerate(gen_out.scores)]
)
EXPECTED_GENERATION = (
"Here's everything I know about cats. Cats are mammals, they have four legs, they have a tail, they have "
"a face with a nose, eyes, and mouth. They have fur, they have claws, and they have a body that is "
"covered in fur. They are carnivores, so they eat meat. They are also very clean animals, they groom "
"themselves. They have a lot of different breeds. Some are small, some are large. Some are friendly, "
"some are not. They have a lot of different personalities. They can be very independent, or they can be "
"very affectionate. They can be very playful, or they can be very lazy. They can be very intelligent, or "
"they can be very silly. They have a lot of different behaviors. They can be very curious, or they can "
"be very cautious. They can be very vocal, or they can be very quiet. They can be very social, or they "
"can be very solitary. They can be very active, or they can be very inactive. They can be very "
"affectionate, or they can be very aloof. They can be very playful, or they can be very lazy. They can "
"be very intelligent, or they can be very silly. They have a lot of different behaviors. They can be "
"very curious, or they can"
)
EXPECTED_SCORE_SUM = 11017.4971
self.assertEqual(decoded[0], EXPECTED_GENERATION)
self.assertAlmostEqual(score_sum, EXPECTED_SCORE_SUM, places=2)
self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check
@parameterized.expand([("eager"), ("sdpa")]) @parameterized.expand([("eager"), ("sdpa")])
@require_torch_gpu @require_torch_gpu
@@ -339,41 +355,42 @@ class CacheHardIntegrationTest(unittest.TestCase):
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
"""Tests that different cache implementations work well with eager and SDPA inference""" """Tests that different cache implementations work well with eager and SDPA inference"""
EXPECTED_GENERATION = [ EXPECTED_GENERATION = [
"The best color is the one that complements the skin tone of the", "The best color is the one that is most suitable for the purpose.",
"We should not undermind the issues at hand.\nWe should not undermind the issues", "We should not undermind the issues at hand, but instead, we should focus on the things",
] ]
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", "Qwen/Qwen3-4B",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
).to(torch_device) device_map="auto",
)
inputs = tokenizer( inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device) ).to(model.device)
generation_kwargs = {"do_sample": False, "max_new_tokens": 10, "return_dict_in_generate": True}
set_seed(0) set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) gen_out = model.generate(**inputs, **generation_kwargs)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, dynamic"): with self.subTest(f"{attn_implementation}, dynamic"):
self.assertListEqual(decoded, EXPECTED_GENERATION) self.assertListEqual(decoded, EXPECTED_GENERATION)
self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check
set_seed(0) set_seed(0)
gen_out = model.generate( gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static", disable_compile=True)
**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"): with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION) self.assertListEqual(decoded, EXPECTED_GENERATION)
self.assertIsInstance(gen_out.past_key_values, StaticCache) # sanity check
set_seed(0) set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static") gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"): with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION) self.assertListEqual(decoded, EXPECTED_GENERATION)
self.assertIsInstance(gen_out.past_key_values, StaticCache) # sanity check
@require_torch_accelerator @require_torch_accelerator
@slow @slow
@@ -446,9 +463,9 @@ class CacheHardIntegrationTest(unittest.TestCase):
responses.append(response) responses.append(response)
EXPECTED_DECODED_TEXT = [ EXPECTED_DECODED_TEXT = [
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful " "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an "
"way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a " "enriching experience that broadens our horizons and allows us to explore the world beyond our comfort "
"first-time adventurer, there is always something", "zones. Whether it's a short weekend getaway",
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital " "You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
"of France.\n\n\n\n\n\n\n<|endoftext|>", "of France.\n\n\n\n\n\n\n<|endoftext|>",
] ]
@@ -506,6 +523,7 @@ class CacheHardIntegrationTest(unittest.TestCase):
@require_torch_multi_gpu @require_torch_multi_gpu
@slow @slow
@require_read_token
def test_static_cache_multi_gpu(self): def test_static_cache_multi_gpu(self):
"""Regression test for #35164: static cache with multi-gpu""" """Regression test for #35164: static cache with multi-gpu"""