From 2faa09530bc5d29756bddfec12037c066cc85a02 Mon Sep 17 00:00:00 2001 From: Matthijs Hollemans Date: Tue, 30 May 2023 15:06:58 +0200 Subject: [PATCH] fix Whisper tests on GPU (#23753) * move input features to GPU * skip these tests because undefined behavior * unskip tests --- tests/models/whisper/test_modeling_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2be7f6884e..3eee5ad496 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1477,7 +1477,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") model.to(torch_device) input_speech = self._load_datasamples(4)[-1:] - input_features = processor(input_speech, return_tensors="pt").input_features + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) output_without_prompt = model.generate(input_features) prompt_ids = processor.get_prompt_ids("Leighton") @@ -1494,7 +1494,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") model.to(torch_device) input_speech = self._load_datasamples(1) - input_features = processor(input_speech, return_tensors="pt").input_features + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) task = "translate" language = "de" expected_tokens = [f"<|{task}|>", f"<|{language}|>"] @@ -1513,7 +1513,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") model.to(torch_device) input_speech = self._load_datasamples(1) - input_features = processor(input_speech, return_tensors="pt").input_features + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) prompt = "test prompt" prompt_ids = processor.get_prompt_ids(prompt)