[ASR Pipeline] Clarify return timestamps (#25344)
* [ASR Pipeline] Clarify return timestamps * fix indentation * fix ctc check * fix ctc error message! * fix test * fix other test * add new tests * final comment
This commit is contained in:
@@ -156,8 +156,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
feature_extractor ([`SequenceFeatureExtractor`]):
|
feature_extractor ([`SequenceFeatureExtractor`]):
|
||||||
The feature extractor that will be used by the pipeline to encode waveform for the model.
|
The feature extractor that will be used by the pipeline to encode waveform for the model.
|
||||||
chunk_length_s (`float`, *optional*, defaults to 0):
|
chunk_length_s (`float`, *optional*, defaults to 0):
|
||||||
The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled (default). Only
|
The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
|
||||||
available for CTC models, e.g. [`Wav2Vec2ForCTC`].
|
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
@@ -247,14 +246,29 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
|
np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
|
||||||
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
||||||
inference to provide more context to the model). Only use `stride` with CTC models.
|
inference to provide more context to the model). Only use `stride` with CTC models.
|
||||||
return_timestamps (*optional*, `str`):
|
return_timestamps (*optional*, `str` or `bool`):
|
||||||
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
|
Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
|
||||||
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
|
other sequence-to-sequence models.
|
||||||
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
|
|
||||||
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
|
For CTC models, timestamps can take one of two formats:
|
||||||
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
- `"char"`: the pipeline will return timestamps along the text for every character in the text. For
|
||||||
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
|
instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
|
||||||
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
|
||||||
|
`0.6` seconds.
|
||||||
|
- `"word"`: the pipeline will return timestamps along the text for every word in the text. For
|
||||||
|
instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
|
||||||
|
(1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
|
||||||
|
before `0.9` seconds.
|
||||||
|
|
||||||
|
For the Whisper model, timestamps can take one of two formats:
|
||||||
|
- `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
|
||||||
|
through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
|
||||||
|
by inspecting the cross-attention weights.
|
||||||
|
- `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
|
||||||
|
For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
|
||||||
|
model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
|
||||||
|
Note that a segment of text refers to a sequence of one or more words, rather than individual
|
||||||
|
words as with word-level timestamps.
|
||||||
generate_kwargs (`dict`, *optional*):
|
generate_kwargs (`dict`, *optional*):
|
||||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||||
complete overview of generate, check the [following
|
complete overview of generate, check the [following
|
||||||
@@ -264,12 +278,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
Return:
|
Return:
|
||||||
`Dict`: A dictionary with the following keys:
|
`Dict`: A dictionary with the following keys:
|
||||||
- **text** (`str` ) -- The recognized text.
|
- **text** (`str`): The recognized text.
|
||||||
- **chunks** (*optional*, `List[Dict]`)
|
- **chunks** (*optional*, `List[Dict]`)
|
||||||
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
||||||
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
|
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
|
||||||
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
||||||
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
||||||
"""
|
"""
|
||||||
return super().__call__(inputs, **kwargs)
|
return super().__call__(inputs, **kwargs)
|
||||||
|
|
||||||
@@ -308,6 +322,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if decoder_kwargs is not None:
|
if decoder_kwargs is not None:
|
||||||
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
||||||
if return_timestamps is not None:
|
if return_timestamps is not None:
|
||||||
|
if self.type == "seq2seq" and return_timestamps:
|
||||||
|
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
|
||||||
|
if self.type == "ctc_with_lm" and return_timestamps != "word":
|
||||||
|
raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
|
||||||
|
if self.type == "ctc" and return_timestamps not in ["char", "word"]:
|
||||||
|
raise ValueError(
|
||||||
|
"CTC can either predict character (char) level timestamps, or word level timestamps."
|
||||||
|
"Set `return_timestamps='char'` or `return_timestamps='word'` as required."
|
||||||
|
)
|
||||||
|
if self.type == "seq2seq_whisper" and return_timestamps == "char":
|
||||||
|
raise ValueError(
|
||||||
|
"Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
|
||||||
|
"Use `return_timestamps='word'` or `return_timestamps=True` respectively."
|
||||||
|
)
|
||||||
forward_params["return_timestamps"] = return_timestamps
|
forward_params["return_timestamps"] = return_timestamps
|
||||||
postprocess_params["return_timestamps"] = return_timestamps
|
postprocess_params["return_timestamps"] = return_timestamps
|
||||||
if return_language is not None:
|
if return_language is not None:
|
||||||
@@ -497,13 +525,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# Optional return types
|
# Optional return types
|
||||||
optional = {}
|
optional = {}
|
||||||
|
|
||||||
if return_timestamps and self.type == "seq2seq":
|
|
||||||
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
|
|
||||||
if return_timestamps == "char" and self.type == "ctc_with_lm":
|
|
||||||
raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
|
|
||||||
if return_timestamps == "char" and self.type == "seq2seq_whisper":
|
|
||||||
raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
|
|
||||||
|
|
||||||
if return_language is not None and self.type != "seq2seq_whisper":
|
if return_language is not None and self.type != "seq2seq_whisper":
|
||||||
raise ValueError("Only whisper can return language for now.")
|
raise ValueError("Only whisper can return language for now.")
|
||||||
|
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
# Non CTC models cannot use return_timestamps
|
# Non CTC models cannot use return_timestamps
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
|
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
|
||||||
):
|
):
|
||||||
outputs = speech_recognizer(audio, return_timestamps="char")
|
outputs = speech_recognizer(audio, return_timestamps="char")
|
||||||
|
|
||||||
@@ -161,7 +161,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
# Non CTC models cannot use return_timestamps
|
# Non CTC models cannot use return_timestamps
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
|
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
|
||||||
):
|
):
|
||||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||||
|
|
||||||
@@ -261,6 +261,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# CTC + LM models cannot use return_timestamps="char"
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "^CTC with LM can only predict word level timestamps, set `return_timestamps='word'`$"
|
||||||
|
):
|
||||||
|
_ = speech_recognizer(filename, return_timestamps="char")
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
@@ -750,6 +755,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
# Whisper can only predict segment level timestamps or word level, not character level
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
|
||||||
|
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
|
||||||
|
):
|
||||||
|
_ = speech_recognizer(filename, return_timestamps="char")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@@ -1082,6 +1095,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# CTC models must specify return_timestamps type - cannot set `return_timestamps=True` blindly
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"^CTC can either predict character (char) level timestamps, or word level timestamps."
|
||||||
|
"Set `return_timestamps='char'` or `return_timestamps='word'` as required.$",
|
||||||
|
):
|
||||||
|
_ = speech_recognizer(audio, return_timestamps=True)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user