From ad0d7d17451fea6457c9ee81898f7f64ad7ef848 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 25 Feb 2022 14:06:45 +0100 Subject: [PATCH] Adding the option to return_timestamps on pure CTC ASR models. (#15792) * Adding the option to return_timestamps on pure CTC ASR models. * Remove `math.prod` which was introduced in Python 3.8 * int are not floats. * Reworking the PR to support "char" vs "word" output. * Fixup! * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Update src/transformers/pipelines/automatic_speech_recognition.py Co-authored-by: Patrick von Platen * Quality. Co-authored-by: Patrick von Platen --- .../models/hubert/configuration_hubert.py | 5 +- .../models/sew/configuration_sew.py | 5 +- .../models/sew_d/configuration_sew_d.py | 5 +- .../unispeech/configuration_unispeech.py | 5 +- .../configuration_unispeech_sat.py | 5 +- .../models/wav2vec2/configuration_wav2vec2.py | 5 +- .../models/wavlm/configuration_wavlm.py | 5 +- .../pipelines/automatic_speech_recognition.py | 51 +++++- ..._pipelines_automatic_speech_recognition.py | 155 +++++++++++++++++- 9 files changed, 218 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py index df1bbc860f..9b104aa9c5 100644 --- a/src/transformers/models/hubert/configuration_hubert.py +++ b/src/transformers/models/hubert/configuration_hubert.py @@ -14,7 +14,8 @@ # limitations under the License. """ Hubert model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -253,4 +254,4 @@ class HubertConfig(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/models/sew/configuration_sew.py b/src/transformers/models/sew/configuration_sew.py index b253f99a42..ad6a6afa69 100644 --- a/src/transformers/models/sew/configuration_sew.py +++ b/src/transformers/models/sew/configuration_sew.py @@ -14,7 +14,8 @@ # limitations under the License. """ SEW model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -248,4 +249,4 @@ class SEWConfig(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/models/sew_d/configuration_sew_d.py b/src/transformers/models/sew_d/configuration_sew_d.py index a0f5fb2e60..d4cbf038df 100644 --- a/src/transformers/models/sew_d/configuration_sew_d.py +++ b/src/transformers/models/sew_d/configuration_sew_d.py @@ -14,7 +14,8 @@ # limitations under the License. """ SEW-D model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -284,4 +285,4 @@ class SEWDConfig(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/models/unispeech/configuration_unispeech.py b/src/transformers/models/unispeech/configuration_unispeech.py index 05c42b2457..919a3b4824 100644 --- a/src/transformers/models/unispeech/configuration_unispeech.py +++ b/src/transformers/models/unispeech/configuration_unispeech.py @@ -14,7 +14,8 @@ # limitations under the License. """ UniSpeech model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -294,4 +295,4 @@ class UniSpeechConfig(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py index d76978ea30..98fc160b5a 100644 --- a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py @@ -14,7 +14,8 @@ # limitations under the License. """ UniSpeechSat model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -311,4 +312,4 @@ class UniSpeechSatConfig(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 71c81cdc79..f675f6799f 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -14,7 +14,8 @@ # limitations under the License. """ Wav2Vec2 model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -334,4 +335,4 @@ class Wav2Vec2Config(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/models/wavlm/configuration_wavlm.py b/src/transformers/models/wavlm/configuration_wavlm.py index 1c0c1f0d90..1a44fc10de 100644 --- a/src/transformers/models/wavlm/configuration_wavlm.py +++ b/src/transformers/models/wavlm/configuration_wavlm.py @@ -14,7 +14,8 @@ # limitations under the License. """ WavLM model configuration""" -import math +import functools +import operator from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -335,4 +336,4 @@ class WavLMConfig(PretrainedConfig): @property def inputs_to_logits_ratio(self): - return math.prod(self.conv_stride) + return functools.reduce(operator.mul, self.conv_stride, 1) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index df0c24a5a5..3552a23ce3 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -165,10 +165,23 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in decoding (but used at inference to provide more context to the model). Only use `stride` with CTC models. + return_timestamps (*optional*, `str`): + Only available for pure CTC models. If set to `"char"`, the pipeline will return `timestamps` along the + text for every character in the text. For instance if you get `[{"text": "h", "timestamps": (0.5,0.6), + {"text": "i", "timestamps": (0.7, .9)}]`, then it means the model predicts that the letter "h" was + pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return + `timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ", + "timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model + predicts that the word "hi" was pronounces before 0.5 and after 0.9 seconds. Return: `Dict`: A dictionary with the following keys: - - **text** (`str`) -- The recognized text. + - **text** (`str` ) -- The recognized text. + - **chunks** (*optional(, `List[Dict]`) + When using `return_timestamps`, the `chunks` will become a list containing all the various text + chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text": + "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing + `"".join(chunk["text"] for chunk in output["chunks"])`. """ return super().__call__(inputs, **kwargs) @@ -183,6 +196,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): postprocess_params = {} if "decoder_kwargs" in kwargs: postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"] + if "return_timestamps" in kwargs: + postprocess_params["return_timestamps"] = kwargs["return_timestamps"] return preprocess_params, {}, postprocess_params @@ -323,7 +338,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): extra = model_inputs return {"is_last": is_last, **out, **extra} - def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None): + def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None): + # Optional return types + optional = {} + + if return_timestamps and self.type != "ctc": + raise ValueError("We cannot return_timestamps yet on non-ctc models !") + if self.type == "ctc_with_lm": final_logits = [] for outputs in model_outputs: @@ -349,6 +370,30 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): tokens = tokens.squeeze(0) text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + if return_timestamps: + if return_timestamps == "char": + decoded = self.tokenizer.decode( + tokens, skip_special_tokens=skip_special_tokens, output_char_offsets=True + ) + elif return_timestamps == "word": + decoded = self.tokenizer.decode( + tokens, skip_special_tokens=skip_special_tokens, output_word_offsets=True + ) + chunks = [] + for item in decoded[f"{return_timestamps}_offsets"]: + start = ( + item["start_offset"] + * self.model.config.inputs_to_logits_ratio + / self.feature_extractor.sampling_rate + ) + stop = ( + item["end_offset"] + * self.model.config.inputs_to_logits_ratio + / self.feature_extractor.sampling_rate + ) + chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)}) + optional["chunks"] = chunks + extra = defaultdict(list) for output in model_outputs: output.pop("tokens", None) @@ -357,4 +402,4 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if k == "is_last": continue extra[k].append(v) - return {"text": text, **extra} + return {"text": text, **optional, **extra} diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 5e1adbc27d..1731c99428 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -82,15 +82,46 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel outputs = speech_recognizer(audio) self.assertEqual(outputs, {"text": ANY(str)}) + # Striding audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate} if speech_recognizer.type == "ctc": outputs = speech_recognizer(audio) self.assertEqual(outputs, {"text": ANY(str)}) + else: # Non CTC models cannot use striding. with self.assertRaises(ValueError): outputs = speech_recognizer(audio) + # Timestamps + audio = np.zeros((34000,)) + if speech_recognizer.type == "ctc": + outputs = speech_recognizer(audio, return_timestamps="char") + self.assertIsInstance(outputs["chunks"], list) + n = len(outputs["chunks"]) + self.assertEqual( + outputs, + { + "text": ANY(str), + "chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)], + }, + ) + + outputs = speech_recognizer(audio, return_timestamps="word") + self.assertIsInstance(outputs["chunks"], list) + n = len(outputs["chunks"]) + self.assertEqual( + outputs, + { + "text": ANY(str), + "chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)], + }, + ) + else: + # Non CTC models cannot use return_timestamps + with self.assertRaises(ValueError): + outputs = speech_recognizer(audio, return_timestamps="char") + @require_torch @slow def test_pt_defaults(self): @@ -302,6 +333,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") filename = ds[40]["file"] + output = speech_recognizer(filename) self.assertEqual(output, {"text": "a man said to the universe sir i exist"}) @@ -322,6 +354,49 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output[0]["text"][:6], "ZBT ZC") + @require_torch + def test_return_timestamps_ctc_fast(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="hf-internal-testing/tiny-random-wav2vec2", + ) + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + # Take short audio to keep the test readable + audio = ds[40]["audio"]["array"][:800] + + output = speech_recognizer(audio, return_timestamps="char") + self.assertEqual( + output, + { + "text": "ZBT ZX G", + "chunks": [ + {"text": " ", "timestamp": (0.0, 0.012)}, + {"text": "Z", "timestamp": (0.012, 0.016)}, + {"text": "B", "timestamp": (0.016, 0.02)}, + {"text": "T", "timestamp": (0.02, 0.024)}, + {"text": " ", "timestamp": (0.024, 0.028)}, + {"text": "Z", "timestamp": (0.028, 0.032)}, + {"text": "X", "timestamp": (0.032, 0.036)}, + {"text": " ", "timestamp": (0.036, 0.04)}, + {"text": "G", "timestamp": (0.04, 0.044)}, + ], + }, + ) + + output = speech_recognizer(audio, return_timestamps="word") + self.assertEqual( + output, + { + "text": "ZBT ZX G", + "chunks": [ + {"text": "ZBT", "timestamp": (0.012, 0.024)}, + {"text": "ZX", "timestamp": (0.028, 0.036)}, + {"text": "G", "timestamp": (0.04, 0.044)}, + ], + }, + ) + @require_torch @require_pyctcdecode def test_chunking_fast_with_lm(self): @@ -399,7 +474,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel @require_torch @slow - def test_chunking(self): + def test_chunking_and_timestamps(self): model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h") feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") @@ -416,11 +491,79 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel audio = ds[40]["audio"]["array"] n_repeats = 10 - audio = np.tile(audio, n_repeats) - output = speech_recognizer([audio], batch_size=2) - expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats - expected = [{"text": expected_text.strip()}] - self.assertEqual(output, expected) + audio_tiled = np.tile(audio, n_repeats) + output = speech_recognizer([audio_tiled], batch_size=2) + self.assertEqual(output, [{"text": ("A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats).strip()}]) + + output = speech_recognizer(audio, return_timestamps="char") + self.assertEqual(audio.shape, (74_400,)) + self.assertEqual(speech_recognizer.feature_extractor.sampling_rate, 16_000) + # The audio is 74_400 / 16_000 = 4.65s long. + self.assertEqual( + output, + { + "text": "A MAN SAID TO THE UNIVERSE SIR I EXIST", + "chunks": [ + {"text": "A", "timestamp": (0.6, 0.62)}, + {"text": " ", "timestamp": (0.62, 0.66)}, + {"text": "M", "timestamp": (0.68, 0.7)}, + {"text": "A", "timestamp": (0.78, 0.8)}, + {"text": "N", "timestamp": (0.84, 0.86)}, + {"text": " ", "timestamp": (0.92, 0.98)}, + {"text": "S", "timestamp": (1.06, 1.08)}, + {"text": "A", "timestamp": (1.14, 1.16)}, + {"text": "I", "timestamp": (1.16, 1.18)}, + {"text": "D", "timestamp": (1.2, 1.24)}, + {"text": " ", "timestamp": (1.24, 1.28)}, + {"text": "T", "timestamp": (1.28, 1.32)}, + {"text": "O", "timestamp": (1.34, 1.36)}, + {"text": " ", "timestamp": (1.38, 1.42)}, + {"text": "T", "timestamp": (1.42, 1.44)}, + {"text": "H", "timestamp": (1.44, 1.46)}, + {"text": "E", "timestamp": (1.46, 1.5)}, + {"text": " ", "timestamp": (1.5, 1.56)}, + {"text": "U", "timestamp": (1.58, 1.62)}, + {"text": "N", "timestamp": (1.64, 1.68)}, + {"text": "I", "timestamp": (1.7, 1.72)}, + {"text": "V", "timestamp": (1.76, 1.78)}, + {"text": "E", "timestamp": (1.84, 1.86)}, + {"text": "R", "timestamp": (1.86, 1.9)}, + {"text": "S", "timestamp": (1.96, 1.98)}, + {"text": "E", "timestamp": (1.98, 2.02)}, + {"text": " ", "timestamp": (2.02, 2.06)}, + {"text": "S", "timestamp": (2.82, 2.86)}, + {"text": "I", "timestamp": (2.94, 2.96)}, + {"text": "R", "timestamp": (2.98, 3.02)}, + {"text": " ", "timestamp": (3.06, 3.12)}, + {"text": "I", "timestamp": (3.5, 3.52)}, + {"text": " ", "timestamp": (3.58, 3.6)}, + {"text": "E", "timestamp": (3.66, 3.68)}, + {"text": "X", "timestamp": (3.68, 3.7)}, + {"text": "I", "timestamp": (3.9, 3.92)}, + {"text": "S", "timestamp": (3.94, 3.96)}, + {"text": "T", "timestamp": (4.0, 4.02)}, + {"text": " ", "timestamp": (4.06, 4.1)}, + ], + }, + ) + output = speech_recognizer(audio, return_timestamps="word") + self.assertEqual( + output, + { + "text": "A MAN SAID TO THE UNIVERSE SIR I EXIST", + "chunks": [ + {"text": "A", "timestamp": (0.6, 0.62)}, + {"text": "MAN", "timestamp": (0.68, 0.86)}, + {"text": "SAID", "timestamp": (1.06, 1.24)}, + {"text": "TO", "timestamp": (1.28, 1.36)}, + {"text": "THE", "timestamp": (1.42, 1.5)}, + {"text": "UNIVERSE", "timestamp": (1.58, 2.02)}, + {"text": "SIR", "timestamp": (2.82, 3.02)}, + {"text": "I", "timestamp": (3.5, 3.52)}, + {"text": "EXIST", "timestamp": (3.66, 4.02)}, + ], + }, + ) @require_torch @slow