[tests] Smaller model in slow cache tests (#37922)
This commit is contained in:
@@ -24,6 +24,7 @@ from transformers.testing_utils import (
|
|||||||
cleanup,
|
cleanup,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
require_read_token,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
@@ -301,37 +302,52 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
class CacheHardIntegrationTest(unittest.TestCase):
|
class CacheHardIntegrationTest(unittest.TestCase):
|
||||||
"""Hard cache integration tests that require loading different models"""
|
"""Hard cache integration tests that require loading different models"""
|
||||||
|
|
||||||
def tearDown(self):
|
def setUp(self):
|
||||||
# Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests
|
# Clears memory before each test. Some tests use large models, which might result in suboptimal torch
|
||||||
# in a row
|
# re-allocation if we run multiple tests in a row without clearing memory.
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
# Clears memory after the last test. See `setUp` for more details.
|
||||||
cleanup(torch_device, gc_collect=True)
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_dynamic_cache_hard(self):
|
def test_dynamic_cache_hard(self):
|
||||||
"""Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
|
"""Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail"""
|
||||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
|
||||||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
|
||||||
)
|
|
||||||
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
|
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
gen_out = model.generate(
|
||||||
|
**inputs, do_sample=True, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
|
||||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
|
||||||
expected_text = (
|
|
||||||
"Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
|
|
||||||
"to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
|
|
||||||
"Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
|
|
||||||
"what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
|
|
||||||
"are also very curious. They like to explore, and they like to play. They are also very fast. They can "
|
|
||||||
"run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
|
|
||||||
"can solve problems. They are also very playful. They like to play with toys, and they like to play with "
|
|
||||||
"other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
|
|
||||||
"also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
|
|
||||||
"clean their litter box.\nCats are also very independent. They don't"
|
|
||||||
)
|
)
|
||||||
self.assertEqual(decoded[0], expected_text)
|
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||||
|
# sum of the scores for the generated tokens
|
||||||
|
input_length = inputs.input_ids.shape[1]
|
||||||
|
score_sum = sum(
|
||||||
|
[score[0][gen_out.sequences[0][input_length + idx]] for idx, score in enumerate(gen_out.scores)]
|
||||||
|
)
|
||||||
|
|
||||||
|
EXPECTED_GENERATION = (
|
||||||
|
"Here's everything I know about cats. Cats are mammals, they have four legs, they have a tail, they have "
|
||||||
|
"a face with a nose, eyes, and mouth. They have fur, they have claws, and they have a body that is "
|
||||||
|
"covered in fur. They are carnivores, so they eat meat. They are also very clean animals, they groom "
|
||||||
|
"themselves. They have a lot of different breeds. Some are small, some are large. Some are friendly, "
|
||||||
|
"some are not. They have a lot of different personalities. They can be very independent, or they can be "
|
||||||
|
"very affectionate. They can be very playful, or they can be very lazy. They can be very intelligent, or "
|
||||||
|
"they can be very silly. They have a lot of different behaviors. They can be very curious, or they can "
|
||||||
|
"be very cautious. They can be very vocal, or they can be very quiet. They can be very social, or they "
|
||||||
|
"can be very solitary. They can be very active, or they can be very inactive. They can be very "
|
||||||
|
"affectionate, or they can be very aloof. They can be very playful, or they can be very lazy. They can "
|
||||||
|
"be very intelligent, or they can be very silly. They have a lot of different behaviors. They can be "
|
||||||
|
"very curious, or they can"
|
||||||
|
)
|
||||||
|
EXPECTED_SCORE_SUM = 11017.4971
|
||||||
|
self.assertEqual(decoded[0], EXPECTED_GENERATION)
|
||||||
|
self.assertAlmostEqual(score_sum, EXPECTED_SCORE_SUM, places=2)
|
||||||
|
self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check
|
||||||
|
|
||||||
@parameterized.expand([("eager"), ("sdpa")])
|
@parameterized.expand([("eager"), ("sdpa")])
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@@ -339,41 +355,42 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
|||||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||||
"""Tests that different cache implementations work well with eager and SDPA inference"""
|
"""Tests that different cache implementations work well with eager and SDPA inference"""
|
||||||
EXPECTED_GENERATION = [
|
EXPECTED_GENERATION = [
|
||||||
"The best color is the one that complements the skin tone of the",
|
"The best color is the one that is most suitable for the purpose.",
|
||||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
"We should not undermind the issues at hand, but instead, we should focus on the things",
|
||||||
]
|
]
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
|
||||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
|
||||||
)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
"NousResearch/Llama-2-7b-chat-hf",
|
"Qwen/Qwen3-4B",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
attn_implementation=attn_implementation,
|
attn_implementation=attn_implementation,
|
||||||
).to(torch_device)
|
device_map="auto",
|
||||||
|
)
|
||||||
inputs = tokenizer(
|
inputs = tokenizer(
|
||||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||||
).to(model.device)
|
).to(model.device)
|
||||||
|
generation_kwargs = {"do_sample": False, "max_new_tokens": 10, "return_dict_in_generate": True}
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
gen_out = model.generate(**inputs, **generation_kwargs)
|
||||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||||
with self.subTest(f"{attn_implementation}, dynamic"):
|
with self.subTest(f"{attn_implementation}, dynamic"):
|
||||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||||
|
self.assertIsInstance(gen_out.past_key_values, DynamicCache) # sanity check
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
gen_out = model.generate(
|
gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static", disable_compile=True)
|
||||||
**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True
|
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||||
)
|
|
||||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
|
||||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||||
|
self.assertIsInstance(gen_out.past_key_values, StaticCache) # sanity check
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static")
|
gen_out = model.generate(**inputs, **generation_kwargs, cache_implementation="static")
|
||||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
|
||||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||||
|
self.assertIsInstance(gen_out.past_key_values, StaticCache) # sanity check
|
||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@slow
|
@slow
|
||||||
@@ -446,9 +463,9 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
|||||||
responses.append(response)
|
responses.append(response)
|
||||||
|
|
||||||
EXPECTED_DECODED_TEXT = [
|
EXPECTED_DECODED_TEXT = [
|
||||||
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful "
|
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an "
|
||||||
"way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a "
|
"enriching experience that broadens our horizons and allows us to explore the world beyond our comfort "
|
||||||
"first-time adventurer, there is always something",
|
"zones. Whether it's a short weekend getaway",
|
||||||
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
|
"You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital "
|
||||||
"of France.\n\n\n\n\n\n\n<|endoftext|>",
|
"of France.\n\n\n\n\n\n\n<|endoftext|>",
|
||||||
]
|
]
|
||||||
@@ -506,6 +523,7 @@ class CacheHardIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
@slow
|
@slow
|
||||||
|
@require_read_token
|
||||||
def test_static_cache_multi_gpu(self):
|
def test_static_cache_multi_gpu(self):
|
||||||
"""Regression test for #35164: static cache with multi-gpu"""
|
"""Regression test for #35164: static cache with multi-gpu"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user