From c7c2f08994b1cacd77ead61e1627b7017df13bf9 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 9 May 2025 03:19:47 +0800 Subject: [PATCH] make `test_speculative_decoding_non_distil` device-agnostic (#38010) * make device-agnostic * use condition --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- tests/models/whisper/test_modeling_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 9ec71d635e..519aed3511 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -2495,9 +2495,9 @@ class WhisperModelIntegrationTests(unittest.TestCase): self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster") @slow - @require_torch_gpu + @require_torch_accelerator def test_speculative_decoding_non_distil(self): - torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + torch_dtype = torch.float16 if torch_device in ["cuda", "xpu"] else torch.float32 model_id = "openai/whisper-large-v2" model = WhisperForConditionalGeneration.from_pretrained( model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True