Fix SDPA tests (#28552)
* skip bf16 test if not supported by device
* fix
* fix bis
* use is_torch_bf16_available_on_device
* use is_torch_fp16_available_on_device
* fix & use public llama
* use 1b model
* fix flaky test

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -457,10 +457,10 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
"""
|
"""
|
||||||
max_new_tokens = 30
|
max_new_tokens = 30
|
||||||
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
tokenizer = LlamaTokenizer.from_pretrained("saibo/llama-1B")
|
||||||
|
|
||||||
model_sdpa = LlamaForCausalLM.from_pretrained(
|
model_sdpa = LlamaForCausalLM.from_pretrained(
|
||||||
"meta-llama/Llama-2-7b-hf",
|
"saibo/llama-1B",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
@@ -468,7 +468,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||||
|
|
||||||
model_eager = LlamaForCausalLM.from_pretrained(
|
model_eager = LlamaForCausalLM.from_pretrained(
|
||||||
"meta-llama/Llama-2-7b-hf",
|
"saibo/llama-1B",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
attn_implementation="eager",
|
attn_implementation="eager",
|
||||||
@@ -488,7 +488,11 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
if not has_sdpa:
|
if not has_sdpa:
|
||||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||||
|
|
||||||
texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
|
texts = [
|
||||||
|
"hi here's a longer context, getting longer and",
|
||||||
|
"Hello this is a very long sentence my friend, very long for real",
|
||||||
|
"Today I am in Paris and",
|
||||||
|
]
|
||||||
|
|
||||||
for padding_side in ["left", "right"]:
|
for padding_side in ["left", "right"]:
|
||||||
tokenizer.padding_side = padding_side
|
tokenizer.padding_side = padding_side
|
||||||
|
|||||||
@@ -84,6 +84,8 @@ from transformers.utils import (
|
|||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
|
is_torch_bf16_available_on_device,
|
||||||
|
is_torch_fp16_available_on_device,
|
||||||
is_torch_fx_available,
|
is_torch_fx_available,
|
||||||
is_torch_sdpa_available,
|
is_torch_sdpa_available,
|
||||||
)
|
)
|
||||||
@@ -3382,8 +3384,13 @@ class ModelTesterMixin:
|
|||||||
if not self.all_model_classes[0]._supports_sdpa:
|
if not self.all_model_classes[0]._supports_sdpa:
|
||||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||||
|
|
||||||
if torch_device == "cpu" and torch_dtype == "float16":
|
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
|
||||||
self.skipTest("float16 not supported on cpu")
|
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||||
|
|
||||||
|
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
|
||||||
|
self.skipTest(
|
||||||
|
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||||
|
)
|
||||||
|
|
||||||
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
|
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
|
||||||
if torch_dtype == "float16":
|
if torch_dtype == "float16":
|
||||||
@@ -3400,7 +3407,7 @@ class ModelTesterMixin:
|
|||||||
("cpu", True, torch.bfloat16): 1e-2,
|
("cpu", True, torch.bfloat16): 1e-2,
|
||||||
("cuda", False, torch.float32): 1e-6,
|
("cuda", False, torch.float32): 1e-6,
|
||||||
("cuda", False, torch.bfloat16): 1e-2,
|
("cuda", False, torch.bfloat16): 1e-2,
|
||||||
("cuda", False, torch.float16): 1e-3,
|
("cuda", False, torch.float16): 5e-3,
|
||||||
("cuda", True, torch.float32): 1e-6,
|
("cuda", True, torch.float32): 1e-6,
|
||||||
("cuda", True, torch.bfloat16): 1e-2,
|
("cuda", True, torch.bfloat16): 1e-2,
|
||||||
("cuda", True, torch.float16): 5e-3,
|
("cuda", True, torch.float16): 5e-3,
|
||||||
@@ -3412,7 +3419,7 @@ class ModelTesterMixin:
|
|||||||
("cpu", True, torch.bfloat16): 1e-2,
|
("cpu", True, torch.bfloat16): 1e-2,
|
||||||
("cuda", False, torch.float32): 1e-4,
|
("cuda", False, torch.float32): 1e-4,
|
||||||
("cuda", False, torch.bfloat16): 1e-2,
|
("cuda", False, torch.bfloat16): 1e-2,
|
||||||
("cuda", False, torch.float16): 1e-3,
|
("cuda", False, torch.float16): 5e-3,
|
||||||
("cuda", True, torch.float32): 1e-4,
|
("cuda", True, torch.float32): 1e-4,
|
||||||
("cuda", True, torch.bfloat16): 3e-2,
|
("cuda", True, torch.bfloat16): 3e-2,
|
||||||
("cuda", True, torch.float16): 5e-3,
|
("cuda", True, torch.float16): 5e-3,
|
||||||
|
|||||||
Reference in New Issue
Block a user