Allow rocm systems to run these tests (#37278)
* Allow rocm systems to run these tests * Fix skipTest logic * Use get_device_properties to check system capabilities
This commit is contained in:
@@ -31,6 +31,7 @@ from transformers import (
|
|||||||
T5Config,
|
T5Config,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
get_device_properties,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -1093,12 +1094,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
torch.compiler.reset()
|
(device_type, major) = get_device_properties()
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
if device_type == "cuda" and major < 8:
|
||||||
major, _ = compute_capability
|
|
||||||
|
|
||||||
if not torch.version.cuda or major < 8:
|
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
|
elif device_type == "rocm" and major < 9:
|
||||||
|
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
|
||||||
|
else:
|
||||||
|
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
|
||||||
|
|
||||||
|
torch.compiler.reset()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_sdpa:
|
if not model_class._supports_sdpa:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from transformers import (
|
|||||||
T5Config,
|
T5Config,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
get_device_properties,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -1083,12 +1084,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
torch.compiler.reset()
|
(device_type, major) = get_device_properties()
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
if device_type == "cuda" and major < 8:
|
||||||
major, _ = compute_capability
|
|
||||||
|
|
||||||
if not torch.version.cuda or major < 8:
|
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
|
elif device_type == "rocm" and major < 9:
|
||||||
|
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
|
||||||
|
else:
|
||||||
|
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
|
||||||
|
|
||||||
|
torch.compiler.reset()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_sdpa:
|
if not model_class._supports_sdpa:
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from transformers.models.auto.modeling_auto import (
|
|||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
get_device_properties,
|
||||||
hub_retry,
|
hub_retry,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
@@ -3763,12 +3764,15 @@ class ModelTesterMixin:
|
|||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
torch.compiler.reset()
|
(device_type, major) = get_device_properties()
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
if device_type == "cuda" and major < 8:
|
||||||
major, _ = compute_capability
|
|
||||||
|
|
||||||
if not torch.version.cuda or major < 8:
|
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
|
elif device_type == "rocm" and major < 9:
|
||||||
|
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
|
||||||
|
else:
|
||||||
|
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
|
||||||
|
|
||||||
|
torch.compiler.reset()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_sdpa:
|
if not model_class._supports_sdpa:
|
||||||
@@ -3808,13 +3812,16 @@ class ModelTesterMixin:
|
|||||||
def test_sdpa_can_compile_dynamic(self):
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
if not self.has_attentions:
|
if not self.has_attentions:
|
||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
torch.compiler.reset()
|
|
||||||
if "cuda" in torch_device:
|
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
|
||||||
major, _ = compute_capability
|
|
||||||
|
|
||||||
if not torch.version.cuda or major < 8:
|
(device_type, major) = get_device_properties()
|
||||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
if device_type == "cuda" and major < 8:
|
||||||
|
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||||
|
elif device_type == "rocm" and major < 9:
|
||||||
|
self.skipTest(reason="This test requires an AMD GPU with compute capability >= 9.0")
|
||||||
|
else:
|
||||||
|
self.skipTest(reason="This test requires a Nvidia or AMD GPU")
|
||||||
|
|
||||||
|
torch.compiler.reset()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_sdpa:
|
if not model_class._supports_sdpa:
|
||||||
|
|||||||
Reference in New Issue
Block a user