From 8f9fa3b081403288bb2b3f97244053b5afe46084 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 16 Aug 2024 18:34:13 +0800 Subject: [PATCH] [tests] make test_sdpa_equivalence device-agnostic (#32520) * fix on xpu * [run_all] --- tests/models/gemma/test_modeling_gemma.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 831ce1dec6..b564d51216 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -27,6 +27,7 @@ from transformers.testing_utils import ( require_flash_attn, require_read_token, require_torch, + require_torch_accelerator, require_torch_gpu, require_torch_sdpa, slow, @@ -460,7 +461,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi self.skipTest(reason="Gemma flash attention does not support right padding") @require_torch_sdpa - @require_torch_gpu + @require_torch_accelerator @slow def test_sdpa_equivalence(self): for model_class in self.all_model_classes: