From aa478567f824986e2ea0fab9739b409123824bef Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 10 Apr 2025 13:33:01 +0200 Subject: [PATCH] Allow rocm systems to run these tests (#37278) * Allow rocm systems to run these tests * Fix skipTest logic * Use get_device_properties to check system capabilities --- .../models/musicgen/test_modeling_musicgen.py | 14 +++++---- .../test_modeling_musicgen_melody.py | 14 +++++---- tests/test_modeling_common.py | 29 ++++++++++++------- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 1df5ec6ab0..28801cd1e2 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -31,6 +31,7 @@ from transformers import ( T5Config, ) from transformers.testing_utils import ( + get_device_properties, is_torch_available, require_flash_attn, require_torch, @@ -1093,12 +1094,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - torch.compiler.reset() - compute_capability = torch.cuda.get_device_capability() - major, _ = compute_capability - - if not torch.version.cuda or major < 8: + (device_type, major) = get_device_properties() + if device_type == "cuda" and major < 8: self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + elif device_type == "rocm" and major < 9: + self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0") + else: + self.skipTest(reason="This test requires a Nvidia or AMD GPU") + + torch.compiler.reset() for model_class in self.all_model_classes: if not model_class._supports_sdpa: diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 7501d83717..bf441bb19e 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -30,6 +30,7 @@ from transformers import ( T5Config, ) from transformers.testing_utils import ( + get_device_properties, is_torch_available, is_torchaudio_available, require_flash_attn, @@ -1083,12 +1084,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - torch.compiler.reset() - compute_capability = torch.cuda.get_device_capability() - major, _ = compute_capability - - if not torch.version.cuda or major < 8: + (device_type, major) = get_device_properties() + if device_type == "cuda" and major < 8: self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + elif device_type == "rocm" and major < 9: + self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0") + else: + self.skipTest(reason="This test requires a Nvidia or AMD GPU") + + torch.compiler.reset() for model_class in self.all_model_classes: if not model_class._supports_sdpa: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 401b7df711..be65971c95 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -72,6 +72,7 @@ from transformers.models.auto.modeling_auto import ( ) from transformers.testing_utils import ( CaptureLogger, + get_device_properties, hub_retry, is_flaky, require_accelerate, @@ -3763,12 +3764,15 @@ class ModelTesterMixin: if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - torch.compiler.reset() - compute_capability = torch.cuda.get_device_capability() - major, _ = compute_capability - - if not torch.version.cuda or major < 8: + (device_type, major) = get_device_properties() + if device_type == "cuda" and major < 8: self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + elif device_type == "rocm" and major < 9: + self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0") + else: + self.skipTest(reason="This test requires a Nvidia or AMD GPU") + + torch.compiler.reset() for model_class in self.all_model_classes: if not model_class._supports_sdpa: @@ -3808,13 +3812,16 @@ class ModelTesterMixin: def test_sdpa_can_compile_dynamic(self): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - torch.compiler.reset() - if "cuda" in torch_device: - compute_capability = torch.cuda.get_device_capability() - major, _ = compute_capability - if not torch.version.cuda or major < 8: - self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + (device_type, major) = get_device_properties() + if device_type == "cuda" and major < 8: + self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + elif device_type == "rocm" and major < 9: + self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0") + else: + self.skipTest(reason="This test requires a Nvidia or AMD GPU") + + torch.compiler.reset() for model_class in self.all_model_classes: if not model_class._supports_sdpa: