[ASR Pipeline] Clarify return timestamps (#25344)
* [ASR Pipeline] Clarify return timestamps * fix indentation * fix ctc check * fix ctc error message! * fix test * fix other test * add new tests * final comment
This commit is contained in:
@@ -156,8 +156,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
feature_extractor ([`SequenceFeatureExtractor`]):
|
||||
The feature extractor that will be used by the pipeline to encode waveform for the model.
|
||||
chunk_length_s (`float`, *optional*, defaults to 0):
|
||||
The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default). Only
|
||||
available for CTC models, e.g. [`Wav2Vec2ForCTC`].
|
||||
The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -247,14 +246,29 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
|
||||
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
||||
inference to provide more context to the model). Only use `stride` with CTC models.
|
||||
return_timestamps (*optional*, `str`):
|
||||
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
|
||||
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
|
||||
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
|
||||
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
|
||||
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
|
||||
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
||||
return_timestamps (*optional*, `str` or `bool`):
|
||||
Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
|
||||
other sequence-to-sequence models.
|
||||
|
||||
For CTC models, timestamps can take one of two formats:
|
||||
- `"char"`: the pipeline will return timestamps along the text for every character in the text. For
|
||||
instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
|
||||
0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
|
||||
`0.6` seconds.
|
||||
- `"word"`: the pipeline will return timestamps along the text for every word in the text. For
|
||||
instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
|
||||
(1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
|
||||
before `0.9` seconds.
|
||||
|
||||
For the Whisper model, timestamps can take one of two formats:
|
||||
- `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
|
||||
through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
|
||||
by inspecting the cross-attention weights.
|
||||
- `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
|
||||
For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
|
||||
model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
|
||||
Note that a segment of text refers to a sequence of one or more words, rather than individual
|
||||
words as with word-level timestamps.
|
||||
generate_kwargs (`dict`, *optional*):
|
||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||
complete overview of generate, check the [following
|
||||
@@ -264,7 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
|
||||
Return:
|
||||
`Dict`: A dictionary with the following keys:
|
||||
- **text** (`str` ) -- The recognized text.
|
||||
- **text** (`str`): The recognized text.
|
||||
- **chunks** (*optional(, `List[Dict]`)
|
||||
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
||||
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
|
||||
@@ -308,6 +322,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if decoder_kwargs is not None:
|
||||
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
||||
if return_timestamps is not None:
|
||||
if self.type == "seq2seq" and return_timestamps:
|
||||
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
|
||||
if self.type == "ctc_with_lm" and return_timestamps != "word":
|
||||
raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
|
||||
if self.type == "ctc" and return_timestamps not in ["char", "word"]:
|
||||
raise ValueError(
|
||||
"CTC can either predict character (char) level timestamps, or word level timestamps."
|
||||
"Set `return_timestamps='char'` or `return_timestamps='word'` as required."
|
||||
)
|
||||
if self.type == "seq2seq_whisper" and return_timestamps == "char":
|
||||
raise ValueError(
|
||||
"Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
|
||||
"Use `return_timestamps='word'` or `return_timestamps=True` respectively."
|
||||
)
|
||||
forward_params["return_timestamps"] = return_timestamps
|
||||
postprocess_params["return_timestamps"] = return_timestamps
|
||||
if return_language is not None:
|
||||
@@ -497,13 +525,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
# Optional return types
|
||||
optional = {}
|
||||
|
||||
if return_timestamps and self.type == "seq2seq":
|
||||
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
|
||||
if return_timestamps == "char" and self.type == "ctc_with_lm":
|
||||
raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
|
||||
if return_timestamps == "char" and self.type == "seq2seq_whisper":
|
||||
raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
|
||||
|
||||
if return_language is not None and self.type != "seq2seq_whisper":
|
||||
raise ValueError("Only whisper can return language for now.")
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
else:
|
||||
# Non CTC models cannot use return_timestamps
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
|
||||
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
|
||||
):
|
||||
outputs = speech_recognizer(audio, return_timestamps="char")
|
||||
|
||||
@@ -161,7 +161,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
|
||||
# Non CTC models cannot use return_timestamps
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
|
||||
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
|
||||
):
|
||||
_ = speech_recognizer(waveform, return_timestamps="char")
|
||||
|
||||
@@ -261,6 +261,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
# CTC + LM models cannot use return_timestamps="char"
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "^CTC with LM can only predict word level timestamps, set `return_timestamps='word'`$"
|
||||
):
|
||||
_ = speech_recognizer(filename, return_timestamps="char")
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
@@ -750,6 +755,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Whisper can only predict segment level timestamps or word level, not character level
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
|
||||
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
|
||||
):
|
||||
_ = speech_recognizer(filename, return_timestamps="char")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@@ -1082,6 +1095,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
# CTC models must specify return_timestamps type - cannot set `return_timestamps=True` blindly
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"^CTC can either predict character (char) level timestamps, or word level timestamps."
|
||||
"Set `return_timestamps='char'` or `return_timestamps='word'` as required.$",
|
||||
):
|
||||
_ = speech_recognizer(audio, return_timestamps=True)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user