enable 5 cases on XPU (#37507)

* make speecht5 test_batch_generation pass on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enable 4 GlmIntegrationTest cases on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Update src/transformers/testing_utils.py

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Yao Matrix
2025-04-16 15:28:02 +08:00
committed by GitHub
parent 3165eb7c28
commit 5ab7a7c640
3 changed files with 16 additions and 2 deletions

View File

@@ -1026,6 +1026,19 @@ def require_torch_large_gpu(test_case, memory: float = 20):
)(test_case) )(test_case)
def require_torch_large_accelerator(test_case, memory: float = 20):
"""Decorator marking a test that requires an accelerator with more than `memory` GiB of memory."""
if torch_device != "cuda" and torch_device != "xpu":
return unittest.skip(reason=f"test requires a GPU or XPU with more than {memory} GiB of memory")(test_case)
torch_accelerator_module = getattr(torch, torch_device)
return unittest.skipUnless(
torch_accelerator_module.get_device_properties(0).total_memory / 1024**3 > memory,
f"test requires a GPU or XPU with more than {memory} GiB of memory",
)(test_case)
def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case):
""" """
Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled.

View File

@@ -22,7 +22,7 @@ from transformers.testing_utils import (
is_flaky, is_flaky,
require_flash_attn, require_flash_attn,
require_torch, require_torch,
require_torch_large_gpu, require_torch_large_accelerator,
require_torch_sdpa, require_torch_sdpa,
slow, slow,
torch_device, torch_device,
@@ -309,7 +309,7 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@slow @slow
@require_torch_large_gpu @require_torch_large_accelerator
class GlmIntegrationTest(unittest.TestCase): class GlmIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"] input_text = ["Hello I am doing", "Hi today"]
model_id = "THUDM/glm-4-9b" model_id = "THUDM/glm-4-9b"

View File

@@ -1223,6 +1223,7 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
"Mismatch in waveform between standalone and integrated vocoder for single instance generation.", "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
) )
@require_deterministic_for_xpu
def test_batch_generation(self): def test_batch_generation(self):
model = self.default_model model = self.default_model
processor = self.default_processor processor = self.default_processor