From e55b33ceb4b0ba3c8c11f20b6e8d6ca4b48246d4 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Mon, 19 Aug 2024 19:46:59 +0800 Subject: [PATCH] [tests] make `test_sdpa_can_compile_dynamic` device-agnostic (#32519) * enable * fix --- tests/test_modeling_common.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b5bad16a02..7cbc2f3e28 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -79,6 +79,7 @@ from transformers.testing_utils import ( require_read_token, require_safetensors, require_torch, + require_torch_accelerator, require_torch_gpu, require_torch_multi_accelerator, require_torch_multi_gpu, @@ -4105,17 +4106,17 @@ class ModelTesterMixin: _ = model(**inputs_dict) @require_torch_sdpa - @require_torch_gpu + @require_torch_accelerator @slow def test_sdpa_can_compile_dynamic(self): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") + if "cuda" in torch_device: + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability - compute_capability = torch.cuda.get_device_capability() - major, _ = compute_capability - - if not torch.version.cuda or major < 8: - self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + if not torch.version.cuda or major < 8: + self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") for model_class in self.all_model_classes: if not model_class._supports_sdpa: