From ed53809ac55087f88da873c9a3d061e279555842 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Wed, 16 Apr 2025 17:23:56 +0800 Subject: [PATCH] enable 6 rt_detr_v2 cases on xpu (#37548) * enable 6 rt_detr_v2 cases on xpu Signed-off-by: YAO Matrix * fix style Signed-off-by: YAO Matrix --------- Signed-off-by: YAO Matrix Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py b/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py index e8af79ca7b..e9874f7c51 100644 --- a/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py +++ b/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py @@ -27,7 +27,13 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -636,7 +642,7 @@ class RTDetrV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase self.assertTrue(not failed_cases, message) @parameterized.expand(["float32", "float16", "bfloat16"]) - @require_torch_gpu + @require_torch_accelerator @slow def test_inference_with_different_dtypes(self, torch_dtype_str): torch_dtype = { @@ -658,7 +664,7 @@ class RTDetrV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase _ = model(**self._prepare_for_class(inputs_dict, model_class)) @parameterized.expand(["float32", "float16", "bfloat16"]) - @require_torch_gpu + @require_torch_accelerator @slow def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str): torch_dtype = {