Fix tests in ASR pipeline (#33545)
This commit is contained in:
@@ -295,8 +295,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, {"text": ""})
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
@require_torch
|
||||
@@ -312,8 +312,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, {"text": ""})
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
@slow
|
||||
@@ -542,11 +542,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
|
||||
|
||||
output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
|
||||
output = speech_recognizer([ds[40]["audio"]], chunk_length_s=5, batch_size=4)
|
||||
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])
|
||||
|
||||
@require_torch
|
||||
@@ -1014,8 +1014,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})
|
||||
|
||||
@slow
|
||||
@@ -1032,13 +1032,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, {"text": ""})
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = asr(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = asr(audio)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
filename = ds[40]["file"]
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
data = Audio().encode_example(ds[40]["audio"])["bytes"]
|
||||
output = asr(data)
|
||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||
|
||||
@@ -1058,13 +1056,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, {"text": "(Applausi)"})
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = asr(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = asr(audio)
|
||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||
|
||||
filename = ds[40]["file"]
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
data = Audio().encode_example(ds[40]["audio"])["bytes"]
|
||||
output = asr(data)
|
||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||
|
||||
@@ -1078,13 +1074,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[0]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
|
||||
)
|
||||
output = speech_recognizer(filename, return_timestamps=True)
|
||||
output = speech_recognizer(ds[0]["audio"], return_timestamps=True)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
@@ -1100,7 +1096,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
output = speech_recognizer(filename, return_timestamps="word")
|
||||
output = speech_recognizer(ds[0]["audio"], return_timestamps="word")
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
output,
|
||||
@@ -1135,7 +1131,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
|
||||
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
|
||||
):
|
||||
_ = speech_recognizer(filename, return_timestamps="char")
|
||||
_ = speech_recognizer(audio, return_timestamps="char")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@@ -1147,8 +1143,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
|
||||
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
@@ -1158,7 +1154,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||
)
|
||||
output_2 = speech_recognizer_2(filename)
|
||||
output_2 = speech_recognizer_2(ds[0]["audio"])
|
||||
self.assertEqual(output, output_2)
|
||||
|
||||
# either use generate_kwargs or set the model's generation_config
|
||||
@@ -1170,7 +1166,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
feature_extractor=feature_extractor,
|
||||
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
|
||||
)
|
||||
output_3 = speech_translator(filename)
|
||||
output_3 = speech_translator(ds[0]["audio"])
|
||||
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
||||
|
||||
@slow
|
||||
@@ -1182,10 +1178,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
audio = ds[0]["audio"]
|
||||
|
||||
# 1. English-only model compatible with no language argument
|
||||
output = speech_recognizer(filename)
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
|
||||
@@ -1197,7 +1193,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, "
|
||||
"pass `is_multilingual=True` to generate, or update the generation config.",
|
||||
):
|
||||
_ = speech_recognizer(filename, generate_kwargs={"language": "en"})
|
||||
_ = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})
|
||||
|
||||
# 3. Multilingual model accepts language argument
|
||||
speech_recognizer = pipeline(
|
||||
@@ -1205,7 +1201,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
model="openai/whisper-tiny",
|
||||
framework="pt",
|
||||
)
|
||||
output = speech_recognizer(filename, generate_kwargs={"language": "en"})
|
||||
output = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
|
||||
@@ -1315,8 +1311,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})
|
||||
|
||||
@slow
|
||||
@@ -1331,8 +1327,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
|
||||
|
||||
@slow
|
||||
@@ -1348,9 +1344,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
|
||||
output = speech_recognizer(filename)
|
||||
audio = ds[40]["audio"]
|
||||
output = speech_recognizer(audio)
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
@slow
|
||||
@@ -1561,6 +1556,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
feature_extractor=processor.feature_extractor,
|
||||
max_new_tokens=128,
|
||||
device=torch_device,
|
||||
return_timestamps=True, # to allow longform generation
|
||||
)
|
||||
|
||||
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
||||
|
||||
Reference in New Issue
Block a user