From b5016d5de76df66f77506831c932bdb2578f87b4 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Tue, 13 Aug 2024 18:29:57 +0800
Subject: [PATCH] fix tensors on different devices in `WhisperGenerationMixin` (#32316)

* fix
* enable on xpu
* no manual remove
* move to device
* remove to
* add move to
---
 .../pipelines/test_pipelines_automatic_speech_recognition.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 777319d346..abb07d831a 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -1513,7 +1513,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
     def test_whisper_prompted(self):
         processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
-        model = model.to("cuda")
+        model = model.to(torch_device)
 
         pipe = pipeline(
             "automatic-speech-recognition",
@@ -1523,7 +1523,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             max_new_tokens=128,
             chunk_length_s=30,
             batch_size=16,
-            device="cuda:0",
         )
 
         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
@@ -1531,7 +1530,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
 
         # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
         whisper_prompt = "Mr. Quillter."
-        prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")
+        prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)
 
         unprompted_result = pipe(sample.copy())["text"]
         prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]