Fix tests in ASR pipeline (#33545)

This commit is contained in:
Yoach Lacombe
2024-09-18 16:25:45 +02:00
committed by GitHub
parent 4f1e9bae4e
commit f883827c0a

View File

@@ -295,8 +295,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, {"text": ""})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
@require_torch
@@ -312,8 +312,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, {"text": ""})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@slow
@@ -542,11 +542,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
output = speech_recognizer([ds[40]["audio"]], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
@require_torch
@@ -1014,8 +1014,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})
@slow
@@ -1032,13 +1032,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, {"text": ""})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
audio = ds[40]["audio"]
output = asr(audio)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
data = Audio().encode_example(ds[40]["audio"])["bytes"]
output = asr(data)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
@@ -1058,13 +1056,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, {"text": "(Applausi)"})
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
audio = ds[40]["audio"]
output = asr(audio)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
data = Audio().encode_example(ds[40]["audio"])["bytes"]
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
@@ -1078,13 +1074,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
audio = ds[0]["audio"]
output = speech_recognizer(audio)
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
output = speech_recognizer(ds[0]["audio"], return_timestamps=True)
self.assertEqual(
output,
{
@@ -1100,7 +1096,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
output = speech_recognizer(ds[0]["audio"], return_timestamps="word")
# fmt: off
self.assertEqual(
output,
@@ -1135,7 +1131,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
_ = speech_recognizer(filename, return_timestamps="char")
_ = speech_recognizer(audio, return_timestamps="char")
@slow
@require_torch
@@ -1147,8 +1143,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
@@ -1158,7 +1154,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_2 = speech_recognizer_2(filename)
output_2 = speech_recognizer_2(ds[0]["audio"])
self.assertEqual(output, output_2)
# either use generate_kwargs or set the model's generation_config
@@ -1170,7 +1166,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
output_3 = speech_translator(ds[0]["audio"])
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
@slow
@@ -1182,10 +1178,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
audio = ds[0]["audio"]
# 1. English-only model compatible with no language argument
output = speech_recognizer(filename)
output = speech_recognizer(audio)
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
@@ -1197,7 +1193,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, "
"pass `is_multilingual=True` to generate, or update the generation config.",
):
_ = speech_recognizer(filename, generate_kwargs={"language": "en"})
_ = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})
# 3. Multilingual model accepts language argument
speech_recognizer = pipeline(
@@ -1205,7 +1201,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
model="openai/whisper-tiny",
framework="pt",
)
output = speech_recognizer(filename, generate_kwargs={"language": "en"})
output = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
@@ -1315,8 +1311,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})
@slow
@@ -1331,8 +1327,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
@slow
@@ -1348,9 +1344,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@slow
@@ -1561,6 +1556,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
device=torch_device,
return_timestamps=True, # to allow longform generation
)
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]