Fix SDPA tests (#28552)

* skip bf16 test if not supported by device

* fix

* fix bis

* use is_torch_bf16_available_on_device

* use is_torch_fp16_available_on_device

* fix & use public llama

* use 1b model

* fix flacky test

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
fxmarty
2024-01-17 17:29:18 +01:00
committed by GitHub
parent d6ffe74dfa
commit 2c1eebc121
2 changed files with 19 additions and 8 deletions

View File

@@ -84,6 +84,8 @@ from transformers.utils import (
is_accelerate_available,
is_flax_available,
is_tf_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_sdpa_available,
)
@@ -3382,8 +3384,13 @@ class ModelTesterMixin:
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_device == "cpu" and torch_dtype == "float16":
self.skipTest("float16 not supported on cpu")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
@@ -3400,7 +3407,7 @@ class ModelTesterMixin:
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
@@ -3412,7 +3419,7 @@ class ModelTesterMixin:
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 1e-3,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,