[tests] Smaller model in slow cache tests (#37922)
This commit is contained in:
@@ -24,6 +24,7 @@ from transformers.testing_utils import (
|
||||
cleanup,
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
@@ -301,37 +302,52 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
class CacheHardIntegrationTest(unittest.TestCase):
|
||||
"""Hard cache integration tests that require loading different models"""
|
||||
|
||||
def tearDown(self):
|
||||
# Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests
|
||||
# in a row
|
||||
def setUp(self):
|
||||
# Clears memory before each test. Some tests use large models, which might result in suboptimal torch
|
||||
# re-allocation if we run multiple tests in a row without clearing memory.
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
# Clears memory after the last test. See `setUp` for more details.
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
def test_dynamic_cache_hard(self):
|
||||
"""Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
||||
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = (
|
||||
"Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
|
||||
"to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
|
||||
"Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
|
||||
"what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
|
||||
"are also very curious. They like to explore, and they like to play. They are also very fast. They can "
|
||||
"run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
|
||||
"can solve problems. They are also very playful. They like to play with toys, and they like to play with "
|
||||
"other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
|
||||
"also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
|
||||
"clean their litter box.\nCats are also very independent. They don't"
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=True, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
self.assertEqual(decoded[0], expected_text)
|
||||
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
# sum of the scores for the generated tokens
|
||||
input_length = inputs.input_ids.shape[1]
|
||||
score_sum = sum(
|
||||
[score[0][gen_out.sequences[0][input_length + idx]] for idx, score in enumerate(gen_out.scores)]
|
||||
)
|
||||
|
||||
EXPECTED_GENERATION = (
|
||||
"Here's everything I know about cats. Cats are mammals, they have four legs, they have a tail, they have "
|
||||
"a face with a nose, eyes, and mouth. They have fur, they have claws, and they have a body that is "
|
||||
"covered in fur. They are carnivores, so they eat meat. They are also very clean animals, they groom "
|
||||
"themselves. They have a lot of different breeds. Some are small, some are large. Some are friendly, "
|
||||
"some are not. They have a lot of different personalities. They can be very independent, or they can be "
|
||||
"very affectionate. They can be very playful, or they can be very lazy. They can be very intelligent, or "
|
||||
"they can be very silly. They have a lot of different behaviors. They can be very curious, or they can "
|
||||
"be very cautious. They can be very vocal, or they can be very quiet. They can be very social, or they "
|
||||
"can be very solitary. They can be very active, or they can be very inactive. They can be very "
|
||||
"affectionate, or they can be very aloof. They can be very playful, or they can be very lazy. They can "
|
||||
"be very intelligent, or they can be very silly. They have a lot of different behaviors. They can be "
|
||||
"very curious, or they can"
|
||||
)
|
||||
EXPECTED_SCORE_SUM = 11017.4971
|
||||
self.assertEqual(decoded[0], EXPECTED_GENERATION)
|
||||
self.assertAlmostEqual(score_sum, EXPECTED_SCORE_SUM, places=2)
|
||||
self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check
|
||||
|
||||
@parameterized.expand([("eager"), ("sdpa")])
|
||||
@require_torch_gpu
|
||||
@@ -339,41 +355,42 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||
"""Tests that different cache implementations work well with eager and SDPA inference"""
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
"The best color is the one that is most suitable for the purpose.",
|
||||
"We should not undermind the issues at hand, but instead, we should focus on the things",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
"Qwen/Qwen3-4B",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(torch_device)
|
||||
device_map="auto",
|
||||
)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
generation_kwargs = {"do_sample": False, "max_new_tokens": 10, "return_dict_in_generate": True}
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
gen_out = model.generate(**inputs, **generation_kwargs)
|
||||
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, dynamic"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static", disable_compile=True)
|
||||
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
self.assertIsInstance(gen_out.past_key_values, StaticCache) # sanity check
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static")
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static")
|
||||
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
self.assertIsInstance(gen_out.past_key_values, StaticCache) # sanity check
|
||||
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
@@ -446,9 +463,9 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
||||
responses.append(response)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful "
|
||||
"way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a "
|
||||
"first-time adventurer, there is always something",
|
||||
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an "
|
||||
"enriching experience that broadens our horizons and allows us to explore the world beyond our comfort "
|
||||
"zones. Whether it's a short weekend getaway",
|
||||
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
|
||||
"of France.\n\n\n\n\n\n\n<|endoftext|>",
|
||||
]
|
||||
@@ -506,6 +523,7 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_static_cache_multi_gpu(self):
|
||||
"""Regression test for #35164: static cache with multi-gpu"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user