Tests: remove cuda versions when the result is the same 🧹🧹 (#31955)
remove cuda versions when the result is the same
This commit is contained in:
@@ -538,10 +538,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_model_7b_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
7: "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,",
|
||||
8: "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,",
|
||||
}
|
||||
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,"
|
||||
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
@@ -553,7 +550,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
@@ -641,15 +638,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_speculative_generation(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
7: "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big",
|
||||
8: "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big",
|
||||
9: "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big",
|
||||
}
|
||||
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
@@ -663,7 +652,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
@@ -677,16 +666,10 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
|
||||
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
8: [
|
||||
"My favourite condiment is 100% ketchup. I love it on everything. "
|
||||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
|
||||
],
|
||||
7: [
|
||||
"My favourite condiment is 100% ketchup. I love it on everything. "
|
||||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
|
||||
],
|
||||
}
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"My favourite condiment is 100% ketchup. I love it on everything. "
|
||||
"I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
|
||||
]
|
||||
|
||||
prompts = ["My favourite condiment is "]
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
@@ -699,21 +682,21 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
# Dynamic Cache
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||
|
||||
# Static Cache
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
# Sliding Window Cache
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
# Static Cache + compile
|
||||
forward_function = model.forward
|
||||
@@ -722,7 +705,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
||||
|
||||
# Sliding Window Cache + compile
|
||||
torch._dynamo.reset()
|
||||
@@ -731,7 +714,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user