Fix slow tests for important models to be compatible with A10 runners (#29905)
* fix mistral and mixtral * add pdb * fix mixtral tesst * fix * fix mistral ? * add fix gemma * fix mistral * fix * test * anoter test * fix * fix * fix mistral tests * fix them again * final fixes for mistral * fix padding right * fix whipser fa2 * fix * fix * fix gemma * test * fix llama * fix * fix * fix llama gemma * add class attribute * fix CI * clarify whisper * compute_capability * rename names in some comments * Add # fmt: skip * make style * Update tests/models/mistral/test_modeling_mistral.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * update --------- Co-authored-by: Younes Belkada <younesbelkada@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
@@ -379,40 +380,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
import torch
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
||||
|
||||
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = model.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@@ -500,6 +467,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@is_flaky
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -531,12 +499,21 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
assert torch.allclose(logits_fa, logits, atol=3e-3)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@require_read_token
|
||||
@require_torch_gpu
|
||||
class GemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_fp32(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -554,6 +531,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_fp16(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -573,6 +551,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_fp16_static_cache(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -594,12 +573,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_bf16(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
@@ -611,14 +597,21 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_eager(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
|
||||
@@ -631,15 +624,22 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_read_token
|
||||
def test_model_2b_sdpa(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
|
||||
@@ -652,10 +652,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_flash_attn
|
||||
@require_read_token
|
||||
def test_model_2b_flash_attn(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -677,6 +678,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_2b_4bit(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -695,6 +697,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@unittest.skip("The test will not fit our CI runners")
|
||||
@require_read_token
|
||||
def test_model_7b_fp32(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -712,6 +715,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_fp16(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -731,12 +735,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_bf16(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will read a .txt file",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
@@ -748,8 +759,9 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_fp16_static_cache(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
@@ -772,12 +784,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_7b_4bit(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||
"Hi today I am going to talk about the best way to get rid of acne. miniaturing is a very",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
|
||||
|
||||
@@ -787,4 +806,4 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
Reference in New Issue
Block a user