Fix slow tests for important models to be compatible with A10 runners (#29905)
* fix mistral and mixtral * add pdb * fix mixtral test * fix * fix mistral ? * add fix gemma * fix mistral * fix * test * another test * fix * fix * fix mistral tests * fix them again * final fixes for mistral * fix padding right * fix whisper fa2 * fix * fix * fix gemma * test * fix llama * fix * fix * fix llama gemma * add class attribute * fix CI * clarify whisper * compute_capability * rename names in some comments * Add # fmt: skip * make style * Update tests/models/mistral/test_modeling_mistral.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * update --------- Co-authored-by: Younes Belkada <younesbelkada@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from parameterized import parameterized
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
is_flaky,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
@@ -379,40 +380,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
def test_past_key_values_format(self):
|
def test_past_key_values_format(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@require_flash_attn
|
|
||||||
@require_torch_gpu
|
|
||||||
@pytest.mark.flash_attn_test
|
|
||||||
@slow
|
|
||||||
def test_flash_attn_2_generate_padding_right(self):
|
|
||||||
import torch
|
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
model.save_pretrained(tmpdirname)
|
|
||||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
|
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
|
|
||||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
|
||||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
|
||||||
|
|
||||||
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
tmpdirname,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
attn_implementation="flash_attention_2",
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
).to(torch_device)
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = model.generate(
|
|
||||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@@ -500,6 +467,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
|
@is_flaky
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_equivalence(self):
|
def test_flash_attn_2_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -531,12 +499,21 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
assert torch.allclose(logits_fa, logits, atol=3e-3)
|
assert torch.allclose(logits_fa, logits, atol=3e-3)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
|
||||||
@slow
|
@slow
|
||||||
@require_read_token
|
@require_torch_gpu
|
||||||
class GemmaIntegrationTest(unittest.TestCase):
|
class GemmaIntegrationTest(unittest.TestCase):
|
||||||
input_text = ["Hello I am doing", "Hi today"]
|
input_text = ["Hello I am doing", "Hi today"]
|
||||||
|
# This variable is used to determine which CUDA device we are using on our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_fp32(self):
|
def test_model_2b_fp32(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -554,6 +531,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_fp16(self):
|
def test_model_2b_fp16(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -573,6 +551,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_fp16_static_cache(self):
|
def test_model_2b_fp16_static_cache(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -594,12 +573,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_bf16(self):
|
def test_model_2b_bf16(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = {
|
||||||
|
7: [
|
||||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||||
]
|
],
|
||||||
|
8: [
|
||||||
|
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||||
|
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||||
torch_device
|
torch_device
|
||||||
@@ -611,14 +597,21 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_eager(self):
|
def test_model_2b_eager(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = {
|
||||||
|
7: [
|
||||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||||
]
|
],
|
||||||
|
8: [
|
||||||
|
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||||
|
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
|
||||||
@@ -631,15 +624,22 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||||
|
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_sdpa(self):
|
def test_model_2b_sdpa(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = {
|
||||||
|
7: [
|
||||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||||
]
|
],
|
||||||
|
8: [
|
||||||
|
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||||
|
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
|
||||||
@@ -652,10 +652,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||||
|
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_flash_attn(self):
|
def test_model_2b_flash_attn(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -677,6 +678,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
|
@require_read_token
|
||||||
def test_model_2b_4bit(self):
|
def test_model_2b_4bit(self):
|
||||||
model_id = "google/gemma-2b"
|
model_id = "google/gemma-2b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -695,6 +697,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
@unittest.skip("The test will not fit our CI runners")
|
@unittest.skip("The test will not fit our CI runners")
|
||||||
|
@require_read_token
|
||||||
def test_model_7b_fp32(self):
|
def test_model_7b_fp32(self):
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -712,6 +715,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_7b_fp16(self):
|
def test_model_7b_fp16(self):
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -731,12 +735,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_7b_bf16(self):
|
def test_model_7b_bf16(self):
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = {
|
||||||
|
7: [
|
||||||
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
||||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||||
]
|
],
|
||||||
|
8: [
|
||||||
|
"Hello I am doing a project for my school and I am trying to make a program that will read a .txt file",
|
||||||
|
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||||
torch_device
|
torch_device
|
||||||
@@ -748,8 +759,9 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||||
|
|
||||||
|
@require_read_token
|
||||||
def test_model_7b_fp16_static_cache(self):
|
def test_model_7b_fp16_static_cache(self):
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
@@ -772,12 +784,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
|
@require_read_token
|
||||||
def test_model_7b_4bit(self):
|
def test_model_7b_4bit(self):
|
||||||
model_id = "google/gemma-7b"
|
model_id = "google/gemma-7b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = {
|
||||||
|
7: [
|
||||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||||
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
|
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
|
||||||
]
|
],
|
||||||
|
8: [
|
||||||
|
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||||
|
"Hi today I am going to talk about the best way to get rid of acne. miniaturing is a very",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
|
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
|
||||||
|
|
||||||
@@ -787,4 +806,4 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||||
|
|||||||
@@ -597,8 +597,18 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch_gpu
|
||||||
class LlamaIntegrationTest(unittest.TestCase):
|
class LlamaIntegrationTest(unittest.TestCase):
|
||||||
|
# This variable is used to determine which CUDA device we are using on our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
|
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
|
||||||
@slow
|
@slow
|
||||||
def test_model_7b_logits(self):
|
def test_model_7b_logits(self):
|
||||||
@@ -675,16 +685,25 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
@require_read_token
|
@require_read_token
|
||||||
def test_compile_static_cache(self):
|
def test_compile_static_cache(self):
|
||||||
NUM_TOKENS_TO_GENERATE = 40
|
NUM_TOKENS_TO_GENERATE = 40
|
||||||
EXPECTED_TEXT_COMPLETION = [
|
EXPECTED_TEXT_COMPLETION = {
|
||||||
|
7: [
|
||||||
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
||||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||||
]
|
],
|
||||||
|
8: [
|
||||||
|
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
|
||||||
|
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Simply put, the theory of relativity states that ",
|
"Simply put, the theory of relativity states that ",
|
||||||
"My favorite all time favorite condiment is ketchup.",
|
"My favorite all time favorite condiment is ketchup.",
|
||||||
]
|
]
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
||||||
@@ -718,7 +737,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
cache_position += 1
|
cache_position += 1
|
||||||
|
|
||||||
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -763,6 +782,7 @@ end
|
|||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@slow
|
@slow
|
||||||
|
@unittest.skip("Model is too large")
|
||||||
def test_model_7b_logits(self):
|
def test_model_7b_logits(self):
|
||||||
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf").to(torch_device)
|
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf").to(torch_device)
|
||||||
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
|
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
|
||||||
|
|||||||
@@ -470,39 +470,68 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
self.skipTest("Mistral flash attention does not support right padding")
|
self.skipTest("Mistral flash attention does not support right padding")
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch_gpu
|
||||||
class MistralIntegrationTest(unittest.TestCase):
|
class MistralIntegrationTest(unittest.TestCase):
|
||||||
|
# This variable is used to determine which CUDA device we are using on our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_7b_logits(self):
|
def test_model_7b_logits(self):
|
||||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
|
model = MistralForCausalLM.from_pretrained(
|
||||||
|
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
|
||||||
|
)
|
||||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = model(input_ids).logits.cpu()
|
out = model(input_ids).logits.cpu()
|
||||||
# Expected mean on dim = -1
|
# Expected mean on dim = -1
|
||||||
EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
|
EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
|
||||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||||
# slicing logits[0, 0, 0:30]
|
|
||||||
EXPECTED_SLICE = torch.tensor([-5.8781, -5.8616, -0.1052, -4.7200, -5.8781, -5.8774, -5.8773, -5.8777, -5.8781, -5.8780, -5.8781, -5.8779, -1.0787, 1.7583, -5.8779, -5.8780, -5.8783, -5.8778, -5.8776, -5.8781, -5.8784, -5.8778, -5.8778, -5.8777, -5.8779, -5.8778, -5.8776, -5.8780, -5.8779, -5.8781]) # fmt: skip
|
EXPECTED_SLICE = {
|
||||||
|
7: torch.tensor([-5.8781, -5.8616, -0.1052, -4.7200, -5.8781, -5.8774, -5.8773, -5.8777, -5.8781, -5.8780, -5.8781, -5.8779, -1.0787, 1.7583, -5.8779, -5.8780, -5.8783, -5.8778, -5.8776, -5.8781, -5.8784, -5.8778, -5.8778, -5.8777, -5.8779, -5.8778, -5.8776, -5.8780, -5.8779, -5.8781]),
|
||||||
|
8: torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]),
|
||||||
|
} # fmt: skip
|
||||||
|
|
||||||
print(out[0, 0, :30])
|
print(out[0, 0, :30])
|
||||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)
|
torch.testing.assert_close(
|
||||||
|
out[0, 0, :30], EXPECTED_SLICE[self.cuda_compute_capability_major_version], atol=1e-4, rtol=1e-4
|
||||||
|
)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
backend_empty_cache(torch_device)
|
backend_empty_cache(torch_device)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
@require_bitsandbytes
|
||||||
def test_model_7b_generation(self):
|
def test_model_7b_generation(self):
|
||||||
EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
|
EXPECTED_TEXT_COMPLETION = {
|
||||||
|
7: "My favourite condiment is 100% ketchup. I love it on everything. I'm not a big",
|
||||||
|
8: "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,",
|
||||||
|
}
|
||||||
|
|
||||||
prompt = "My favourite condiment is "
|
prompt = "My favourite condiment is "
|
||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
|
model = MistralForCausalLM.from_pretrained(
|
||||||
|
"mistralai/Mistral-7B-v0.1", device_map={"": torch_device}, load_in_4bit=True
|
||||||
|
)
|
||||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||||
|
|
||||||
# greedy generation outputs
|
# greedy generation outputs
|
||||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
backend_empty_cache(torch_device)
|
backend_empty_cache(torch_device)
|
||||||
@@ -517,7 +546,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
input_ids = [1] + [306, 338] * 2048
|
input_ids = [1] + [306, 338] * 2048
|
||||||
model = MistralForCausalLM.from_pretrained(
|
model = MistralForCausalLM.from_pretrained(
|
||||||
"mistralai/Mistral-7B-v0.1",
|
"mistralai/Mistral-7B-v0.1",
|
||||||
device_map="auto",
|
device_map={"": torch_device},
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
attn_implementation="flash_attention_2",
|
attn_implementation="flash_attention_2",
|
||||||
)
|
)
|
||||||
@@ -544,9 +573,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
# An input with 4097 tokens that is above the size of the sliding window
|
# An input with 4097 tokens that is above the size of the sliding window
|
||||||
input_ids = [1] + [306, 338] * 2048
|
input_ids = [1] + [306, 338] * 2048
|
||||||
model = MistralForCausalLM.from_pretrained(
|
model = MistralForCausalLM.from_pretrained(
|
||||||
"mistralai/Mistral-7B-v0.1",
|
"mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.float16
|
||||||
device_map="auto",
|
|
||||||
attn_implementation="sdpa",
|
|
||||||
)
|
)
|
||||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
@@ -577,9 +604,10 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_speculative_generation(self):
|
def test_speculative_generation(self):
|
||||||
EXPECTED_TEXT_COMPLETION = (
|
EXPECTED_TEXT_COMPLETION = {
|
||||||
"My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
|
7: "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs",
|
||||||
)
|
8: "My favourite condiment is 100% Sriracha. I love the heat, the sweetness, the tang",
|
||||||
|
}
|
||||||
prompt = "My favourite condiment is "
|
prompt = "My favourite condiment is "
|
||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||||
model = MistralForCausalLM.from_pretrained(
|
model = MistralForCausalLM.from_pretrained(
|
||||||
@@ -593,7 +621,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
|
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
|
||||||
)
|
)
|
||||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
backend_empty_cache(torch_device)
|
backend_empty_cache(torch_device)
|
||||||
|
|||||||
@@ -507,6 +507,16 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MixtralIntegrationTest(unittest.TestCase):
|
class MixtralIntegrationTest(unittest.TestCase):
|
||||||
|
# This variable is used to determine which CUDA device we are using on our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_small_model_logits(self):
|
def test_small_model_logits(self):
|
||||||
@@ -518,18 +528,26 @@ class MixtralIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
||||||
# these logits have been obtained with the original megablocks implementation.
|
# these logits have been obtained with the original megablocks implementation.
|
||||||
EXPECTED_LOGITS = torch.Tensor(
|
EXPECTED_LOGITS = {
|
||||||
[[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]
|
7: torch.Tensor([[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]).to(
|
||||||
).to(torch_device)
|
torch_device
|
||||||
|
),
|
||||||
|
8: torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(
|
||||||
|
torch_device
|
||||||
|
),
|
||||||
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(dummy_input).logits
|
logits = model(dummy_input).logits
|
||||||
|
|
||||||
torch.testing.assert_close(logits[0, :3, :3].half(), EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
|
torch.testing.assert_close(
|
||||||
torch.testing.assert_close(logits[1, :3, :3].half(), EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
|
logits[0, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
logits[1, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
# @require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_small_model_logits_batched(self):
|
def test_small_model_logits_batched(self):
|
||||||
model_id = "hf-internal-testing/Mixtral-tiny"
|
model_id = "hf-internal-testing/Mixtral-tiny"
|
||||||
dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device)
|
dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device)
|
||||||
@@ -540,23 +558,48 @@ class MixtralIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
||||||
EXPECTED_LOGITS_LEFT = torch.Tensor(
|
EXPECTED_LOGITS_LEFT = {
|
||||||
|
7: torch.Tensor(
|
||||||
[[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
|
[[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
|
||||||
)
|
).to(torch_device),
|
||||||
|
8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to(
|
||||||
|
torch_device
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# logits[0, -3:, -3:].half()
|
EXPECTED_LOGITS_LEFT_UNPADDED = {
|
||||||
EXPECTED_LOGITS_LEFT_UNPADDED = torch.Tensor(
|
7: torch.Tensor(
|
||||||
[[0.2212, 0.5200, -0.3816], [0.8213, -0.2313, 0.6069], [0.2664, -0.7090, 0.2468]],
|
[[0.2212, 0.5200, -0.3816], [0.8213, -0.2313, 0.6069], [0.2664, -0.7090, 0.2468]],
|
||||||
)
|
).to(torch_device),
|
||||||
|
8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to(
|
||||||
|
torch_device
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# logits[1, -3:, -3:].half()
|
EXPECTED_LOGITS_RIGHT_UNPADDED = {
|
||||||
EXPECTED_LOGITS_RIGHT_UNPADDED = torch.Tensor(
|
7: torch.Tensor([[0.2205, 0.1232, -0.1611], [-0.3484, 0.3030, -1.0312], [0.0742, 0.7930, 0.7969]]).to(
|
||||||
[[0.2205, 0.1232, -0.1611], [-0.3484, 0.3030, -1.0312], [0.0742, 0.7930, 0.7969]]
|
torch_device
|
||||||
)
|
),
|
||||||
|
8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to(
|
||||||
|
torch_device
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(dummy_input, attention_mask=attention_mask).logits
|
logits = model(dummy_input, attention_mask=attention_mask).logits
|
||||||
|
|
||||||
torch.testing.assert_close(logits[0, :3, :3].half(), EXPECTED_LOGITS_LEFT, atol=1e-3, rtol=1e-3)
|
torch.testing.assert_close(
|
||||||
torch.testing.assert_close(logits[0, -3:, -3:].half(), EXPECTED_LOGITS_LEFT_UNPADDED, atol=1e-3, rtol=1e-3)
|
logits[0, :3, :3], EXPECTED_LOGITS_LEFT[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
|
||||||
torch.testing.assert_close(logits[1, -3:, -3:].half(), EXPECTED_LOGITS_RIGHT_UNPADDED, atol=1e-3, rtol=1e-3)
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
logits[0, -3:, -3:],
|
||||||
|
EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version],
|
||||||
|
atol=1e-3,
|
||||||
|
rtol=1e-3,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
logits[1, -3:, -3:],
|
||||||
|
EXPECTED_LOGITS_RIGHT_UNPADDED[self.cuda_compute_capability_major_version],
|
||||||
|
atol=1e-3,
|
||||||
|
rtol=1e-3,
|
||||||
|
)
|
||||||
|
|||||||
@@ -3339,3 +3339,21 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
|||||||
@unittest.skip("The model doesn't support fast init from base")
|
@unittest.skip("The model doesn't support fast init from base")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||||
|
)
|
||||||
|
def test_flash_attn_2_generate_padding_right(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||||
|
)
|
||||||
|
def test_flash_attn_2_inference(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||||
|
)
|
||||||
|
def test_flash_attn_2_inference_padding_right(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -3245,6 +3245,7 @@ class ModelTesterMixin:
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
|
@is_flaky
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
@@ -3338,6 +3339,7 @@ class ModelTesterMixin:
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
|
@is_flaky
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
@@ -3427,6 +3429,7 @@ class ModelTesterMixin:
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
|
@is_flaky
|
||||||
def test_flash_attn_2_generate_left_padding(self):
|
def test_flash_attn_2_generate_left_padding(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
@@ -3470,6 +3473,7 @@ class ModelTesterMixin:
|
|||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
|
@is_flaky
|
||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_generate_padding_right(self):
|
def test_flash_attn_2_generate_padding_right(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
@@ -3888,19 +3892,20 @@ class ModelTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn_2:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
dummy_input = inputs_dict[model.main_input_name]
|
dummy_input = inputs_dict[model.main_input_name]
|
||||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||||
|
batch_size = dummy_attention_mask.shape[0]
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size
|
||||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
|
||||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
# To avoid errors with padding_side=="right"
|
||||||
|
if is_padding_right:
|
||||||
|
dummy_attention_mask = torch.ones_like(dummy_input)
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(
|
||||||
tmpdirname,
|
tmpdirname,
|
||||||
@@ -3916,6 +3921,9 @@ class ModelTesterMixin:
|
|||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||||
|
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||||
|
|
||||||
_ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
|
_ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
|
||||||
# with attention mask
|
# with attention mask
|
||||||
_ = model(
|
_ = model(
|
||||||
|
|||||||
Reference in New Issue
Block a user