fix tensors on different devices in WhisperGenerationMixin (#32316)
* fix * enable on xpu * no manual remove * move to device * remove to * add move to
This commit is contained in:
@@ -1513,7 +1513,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
def test_whisper_prompted(self):
|
def test_whisper_prompted(self):
|
||||||
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
|
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
|
||||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||||
model = model.to("cuda")
|
model = model.to(torch_device)
|
||||||
|
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
"automatic-speech-recognition",
|
"automatic-speech-recognition",
|
||||||
@@ -1523,7 +1523,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
max_new_tokens=128,
|
max_new_tokens=128,
|
||||||
chunk_length_s=30,
|
chunk_length_s=30,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
device="cuda:0",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
||||||
@@ -1531,7 +1530,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
# prompt the model to misspell "Mr Quilter" as "Mr Quillter"
|
# prompt the model to misspell "Mr Quilter" as "Mr Quillter"
|
||||||
whisper_prompt = "Mr. Quillter."
|
whisper_prompt = "Mr. Quillter."
|
||||||
prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")
|
prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
unprompted_result = pipe(sample.copy())["text"]
|
unprompted_result = pipe(sample.copy())["text"]
|
||||||
prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]
|
prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]
|
||||||
|
|||||||
Reference in New Issue
Block a user