Allow rocm systems to run these tests (#37278)

* Allow rocm systems to run these tests

* Fix skipTest logic

* Use get_device_properties to check system capabilities
This commit is contained in:
ivarflakstad
2025-04-10 13:33:01 +02:00
committed by GitHub
parent ae5ce22664
commit aa478567f8
3 changed files with 36 additions and 21 deletions

View File

@@ -31,6 +31,7 @@ from transformers import (
T5Config, T5Config,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
get_device_properties,
is_torch_available, is_torch_available,
require_flash_attn, require_flash_attn,
require_torch, require_torch,
@@ -1093,12 +1094,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
if not self.has_attentions: if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions") self.skipTest(reason="Model architecture does not support attentions")
torch.compiler.reset() (device_type, major) = get_device_properties()
compute_capability = torch.cuda.get_device_capability() if device_type == "cuda" and major < 8:
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
torch.compiler.reset()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if not model_class._supports_sdpa: if not model_class._supports_sdpa:

View File

@@ -30,6 +30,7 @@ from transformers import (
T5Config, T5Config,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
get_device_properties,
is_torch_available, is_torch_available,
is_torchaudio_available, is_torchaudio_available,
require_flash_attn, require_flash_attn,
@@ -1083,12 +1084,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
if not self.has_attentions: if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions") self.skipTest(reason="Model architecture does not support attentions")
torch.compiler.reset() (device_type, major) = get_device_properties()
compute_capability = torch.cuda.get_device_capability() if device_type == "cuda" and major < 8:
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
torch.compiler.reset()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if not model_class._supports_sdpa: if not model_class._supports_sdpa:

View File

@@ -72,6 +72,7 @@ from transformers.models.auto.modeling_auto import (
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
CaptureLogger, CaptureLogger,
get_device_properties,
hub_retry, hub_retry,
is_flaky, is_flaky,
require_accelerate, require_accelerate,
@@ -3763,12 +3764,15 @@ class ModelTesterMixin:
if not self.has_attentions: if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions") self.skipTest(reason="Model architecture does not support attentions")
torch.compiler.reset() (device_type, major) = get_device_properties()
compute_capability = torch.cuda.get_device_capability() if device_type == "cuda" and major < 8:
major, _ = compute_capability
if not torch.version.cuda or major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
torch.compiler.reset()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if not model_class._supports_sdpa: if not model_class._supports_sdpa:
@@ -3808,13 +3812,16 @@ class ModelTesterMixin:
def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_compile_dynamic(self):
if not self.has_attentions: if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions") self.skipTest(reason="Model architecture does not support attentions")
torch.compiler.reset()
if "cuda" in torch_device:
compute_capability = torch.cuda.get_device_capability()
major, _ = compute_capability
if not torch.version.cuda or major < 8: (device_type, major) = get_device_properties()
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") if device_type == "cuda" and major < 8:
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
elif device_type == "rocm" and major < 9:
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
else:
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
torch.compiler.reset()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if not model_class._supports_sdpa: if not model_class._supports_sdpa: