Allow rocm systems to run these tests (#37278)
* Allow rocm systems to run these tests * Fix skipTest logic * Use get_device_properties to check system capabilities
This commit is contained in:
@@ -31,6 +31,7 @@ from transformers import (
|
||||
T5Config,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
get_device_properties,
|
||||
is_torch_available,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@@ -1093,12 +1094,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
torch.compiler.reset()
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
if not torch.version.cuda or major < 8:
|
||||
(device_type, major) = get_device_properties()
|
||||
if device_type == "cuda" and major < 8:
|
||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
elif device_type == "rocm" and major < 9:
|
||||
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
|
||||
else:
|
||||
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
|
||||
|
||||
torch.compiler.reset()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
|
||||
@@ -30,6 +30,7 @@ from transformers import (
|
||||
T5Config,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
get_device_properties,
|
||||
is_torch_available,
|
||||
is_torchaudio_available,
|
||||
require_flash_attn,
|
||||
@@ -1083,12 +1084,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
torch.compiler.reset()
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
if not torch.version.cuda or major < 8:
|
||||
(device_type, major) = get_device_properties()
|
||||
if device_type == "cuda" and major < 8:
|
||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
elif device_type == "rocm" and major < 9:
|
||||
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
|
||||
else:
|
||||
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
|
||||
|
||||
torch.compiler.reset()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
|
||||
Reference in New Issue
Block a user