[TESTS] ASR pipeline (#33925)
* fix whisper translation
* correct slow_unfinished_sequence test
* make fixup
This commit is contained in:
@@ -1212,7 +1212,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
|
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
|
||||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||||
)
|
)
|
||||||
output_2 = speech_recognizer_2(ds[0]["audio"])
|
output_2 = speech_recognizer_2(ds[40]["audio"])
|
||||||
self.assertEqual(output, output_2)
|
self.assertEqual(output, output_2)
|
||||||
|
|
||||||
# either use generate_kwargs or set the model's generation_config
|
# either use generate_kwargs or set the model's generation_config
|
||||||
@@ -1224,7 +1224,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
|
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
|
||||||
)
|
)
|
||||||
output_3 = speech_translator(ds[0]["audio"])
|
output_3 = speech_translator(ds[40]["audio"])
|
||||||
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -1896,15 +1896,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
model="vasista22/whisper-hindi-large-v2",
|
model="vasista22/whisper-hindi-large-v2",
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
# Original model wasn't trained with timestamps and has incorrect generation config
|
|
||||||
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
|
|
||||||
|
|
||||||
# the audio is 4 seconds long
|
# the audio is 4 seconds long
|
||||||
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
|
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
|
||||||
|
|
||||||
|
# Original model wasn't trained with timestamps and has incorrect generation config
|
||||||
out = pipe(
|
out = pipe(
|
||||||
audio,
|
audio,
|
||||||
return_timestamps=True,
|
return_timestamps=True,
|
||||||
|
generate_kwargs={"generation_config": GenerationConfig.from_pretrained("openai/whisper-large-v2")},
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
out,
|
out,
|
||||||
|
|||||||
Reference in New Issue
Block a user