fix tensors on different devices in WhisperGenerationMixin (#32316)

* fix

* enable on xpu

* no manual remove

* move to device

* remove to

* add move to
This commit is contained in:
Fanli Lin
2024-08-13 18:29:57 +08:00
committed by GitHub
parent a5a8291ad1
commit b5016d5de7

View File

@@ -1513,7 +1513,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
def test_whisper_prompted(self):
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
- model = model.to("cuda")
+ model = model.to(torch_device)
pipe = pipeline(
"automatic-speech-recognition",
@@ -1523,7 +1523,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
- device="cuda:0",
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
@@ -1531,7 +1530,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# prompt the model to misspell "Mr Quilter" as "Mr Quillter"
whisper_prompt = "Mr. Quillter."
- prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")
+ prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)
unprompted_result = pipe(sample.copy())["text"]
prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]