[ASR Pipeline] Clarify return timestamps (#25344)

* [ASR Pipeline] Clarify return timestamps

* fix indentation

* fix ctc check

* fix ctc error message!

* fix test

* fix other test

* add new tests

* final comment
This commit is contained in:
Sanchit Gandhi
2023-08-08 10:16:00 +01:00
committed by GitHub
parent 5ea2595ecd
commit dedd11160d
2 changed files with 65 additions and 24 deletions

View File

@@ -136,7 +136,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
else:
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
outputs = speech_recognizer(audio, return_timestamps="char")
@@ -161,7 +161,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
_ = speech_recognizer(waveform, return_timestamps="char")
@@ -261,6 +261,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
# CTC + LM models cannot use return_timestamps="char"
with self.assertRaisesRegex(
ValueError, "^CTC with LM can only predict word level timestamps, set `return_timestamps='word'`$"
):
_ = speech_recognizer(filename, return_timestamps="char")
@require_tf
def test_small_model_tf(self):
@@ -750,6 +755,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
# Whisper can only predict segment level timestamps or word level, not character level
with self.assertRaisesRegex(
ValueError,
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
_ = speech_recognizer(filename, return_timestamps="char")
@slow
@require_torch
@require_torchaudio
@@ -1082,6 +1095,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
# CTC models must specify return_timestamps type - cannot set `return_timestamps=True` blindly
with self.assertRaisesRegex(
ValueError,
"^CTC can either predict character (char) level timestamps, or word level timestamps."
"Set `return_timestamps='char'` or `return_timestamps='word'` as required.$",
):
_ = speech_recognizer(audio, return_timestamps=True)
@require_torch
@slow