From dedd11160d28af01d6355bcac38aa0937eeed7a6 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 8 Aug 2023 10:16:00 +0100 Subject: [PATCH] [ASR Pipeline] Clarify return timestamps (#25344) * [ASR Pipeline] Clarify return timestamps * fix indentation * fix ctc check * fix ctc error message! * fix test * fix other test * add new tests * final comment --- .../pipelines/automatic_speech_recognition.py | 65 ++++++++++++------- ..._pipelines_automatic_speech_recognition.py | 24 ++++++- 2 files changed, 65 insertions(+), 24 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index f32aabe64f..1f2c202ac2 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -156,8 +156,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): feature_extractor ([`SequenceFeatureExtractor`]): The feature extractor that will be used by the pipeline to encode waveform for the model. chunk_length_s (`float`, *optional*, defaults to 0): - The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default). Only - available for CTC models, e.g. [`Wav2Vec2ForCTC`]. + The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default). @@ -247,14 +246,29 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in decoding (but used at inference to provide more context to the model). Only use `stride` with CTC models. - return_timestamps (*optional*, `str`): - Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the - text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, - {"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was - pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return - timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ", - "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model - predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds. + return_timestamps (*optional*, `str` or `bool`): + Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for + other sequence-to-sequence models. + + For CTC models, timestamps can take one of two formats: + - `"char"`: the pipeline will return timestamps along the text for every character in the text. For + instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7, + 0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before + `0.6` seconds. + - `"word"`: the pipeline will return timestamps along the text for every word in the text. For + instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": + (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and + before `0.9` seconds. + + For the Whisper model, timestamps can take one of two formats: + - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted + through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps + by inspecting the cross-attention weights. + - `True`: the pipeline will return timestamps along the text for *segments* of words in the text. + For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the + model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds. + Note that a segment of text refers to a sequence of one or more words, rather than individual + words as with word-level timestamps. generate_kwargs (`dict`, *optional*): The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a complete overview of generate, check the [following @@ -264,12 +278,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): Return: `Dict`: A dictionary with the following keys: - - **text** (`str` ) -- The recognized text. + - **text** (`str`): The recognized text. - **chunks** (*optional(, `List[Dict]`) - When using `return_timestamps`, the `chunks` will become a list containing all the various text - chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": - "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing - `"".join(chunk["text"] for chunk in output["chunks"])`. + When using `return_timestamps`, the `chunks` will become a list containing all the various text + chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": + "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing + `"".join(chunk["text"] for chunk in output["chunks"])`. """ return super().__call__(inputs, **kwargs) @@ -308,6 +322,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if decoder_kwargs is not None: postprocess_params["decoder_kwargs"] = decoder_kwargs if return_timestamps is not None: + if self.type == "seq2seq" and return_timestamps: + raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!") + if self.type == "ctc_with_lm" and return_timestamps != "word": + raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`") + if self.type == "ctc" and return_timestamps not in ["char", "word"]: + raise ValueError( + "CTC can either predict character (char) level timestamps, or word level timestamps." + "Set `return_timestamps='char'` or `return_timestamps='word'` as required." + ) + if self.type == "seq2seq_whisper" and return_timestamps == "char": + raise ValueError( + "Whisper cannot return `char` timestamps, only word level or segment level timestamps. " + "Use `return_timestamps='word'` or `return_timestamps=True` respectively." + ) forward_params["return_timestamps"] = return_timestamps postprocess_params["return_timestamps"] = return_timestamps if return_language is not None: @@ -497,13 +525,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): # Optional return types optional = {} - if return_timestamps and self.type == "seq2seq": - raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !") - if return_timestamps == "char" and self.type == "ctc_with_lm": - raise ValueError("CTC with LM cannot return `char` timestamps, only `word`") - if return_timestamps == "char" and self.type == "seq2seq_whisper": - raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.") - if return_language is not None and self.type != "seq2seq_whisper": raise ValueError("Only whisper can return language for now.") diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 2b77c556f4..2d43cdbc81 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -136,7 +136,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): else: # Non CTC models cannot use return_timestamps with self.assertRaisesRegex( - ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$" + ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$" ): outputs = speech_recognizer(audio, return_timestamps="char") @@ -161,7 +161,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): # Non CTC models cannot use return_timestamps with self.assertRaisesRegex( - ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$" + ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$" ): _ = speech_recognizer(waveform, return_timestamps="char") @@ -261,6 +261,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ], }, ) + # CTC + LM models cannot use return_timestamps="char" + with self.assertRaisesRegex( + ValueError, "^CTC with LM can only predict word level timestamps, set `return_timestamps='word'`$" + ): + _ = speech_recognizer(filename, return_timestamps="char") @require_tf def test_small_model_tf(self): @@ -750,6 +755,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ) # fmt: on + # Whisper can only predict segment level timestamps or word level, not character level + with self.assertRaisesRegex( + ValueError, + "^Whisper cannot return `char` timestamps, only word level or segment level timestamps. " + "Use `return_timestamps='word'` or `return_timestamps=True` respectively.$", + ): + _ = speech_recognizer(filename, return_timestamps="char") + @slow @require_torch @require_torchaudio @@ -1082,6 +1095,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ], }, ) + # CTC models must specify return_timestamps type - cannot set `return_timestamps=True` blindly + with self.assertRaisesRegex( + ValueError, + "^CTC can either predict character (char) level timestamps, or word level timestamps." + "Set `return_timestamps='char'` or `return_timestamps='word'` as required.$", + ): + _ = speech_recognizer(audio, return_timestamps=True) @require_torch @slow