[TESTS] ASR pipeline (#33925)

* fix whisper translation

* correct slow_unfinished_sequence test

* make fixup
This commit is contained in:
Yoach Lacombe
2024-10-10 17:31:22 +02:00
committed by GitHub
parent a37a06a20b
commit e7dfb917f8

View File

@@ -1212,7 +1212,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline( speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
) )
output_2 = speech_recognizer_2(ds[0]["audio"]) output_2 = speech_recognizer_2(ds[40]["audio"])
self.assertEqual(output, output_2) self.assertEqual(output, output_2)
# either use generate_kwargs or set the model's generation_config # either use generate_kwargs or set the model's generation_config
@@ -1224,7 +1224,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"}, generate_kwargs={"task": "transcribe", "language": "<|it|>"},
) )
output_3 = speech_translator(ds[0]["audio"]) output_3 = speech_translator(ds[40]["audio"])
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."}) self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
@slow @slow
@@ -1896,15 +1896,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
model="vasista22/whisper-hindi-large-v2", model="vasista22/whisper-hindi-large-v2",
device=torch_device, device=torch_device,
) )
# Original model wasn't trained with timestamps and has incorrect generation config
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
# the audio is 4 seconds long # the audio is 4 seconds long
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset") audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
# Original model wasn't trained with timestamps and has incorrect generation config
out = pipe( out = pipe(
audio, audio,
return_timestamps=True, return_timestamps=True,
generate_kwargs={"generation_config": GenerationConfig.from_pretrained("openai/whisper-large-v2")},
) )
self.assertEqual( self.assertEqual(
out, out,