From 6e57a56987ff201747f5f01bbce3ed2c0fda1910 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 2 Mar 2022 10:49:05 +0100 Subject: [PATCH] Adding timestamps for CTC with LM in ASR pipeline. (#15863) * Adding timestamps for CTC with LM in ASR pipeline. * iRemove print. * Nit change. --- .../models/wav2vec2/tokenization_wav2vec2.py | 2 +- .../pipelines/automatic_speech_recognition.py | 57 +++++++++++-------- ..._pipelines_automatic_speech_recognition.py | 26 +++++++++ 3 files changed, 61 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index def404a065..e9fec60af8 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -353,7 +353,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): word = char last_state = state - if state == "WORD": + if last_state == "WORD": word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset}) return word_offsets diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 33b6cad814..a93569c5d2 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -313,8 +313,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): # Optional return types optional = {} - if return_timestamps and self.type != "ctc": + if return_timestamps and self.type == "seq2seq": raise ValueError("We cannot return_timestamps yet on non-ctc models !") + if return_timestamps == "char" and self.type == "ctc_with_lm": + raise ValueError("CTC with LM cannot return `char` timestamps, only `words`") final_items = [] key = "logits" if self.type == "ctc_with_lm" else "tokens" @@ -335,34 +337,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if self.type == "ctc_with_lm": if decoder_kwargs is None: decoder_kwargs = {} - text = self.decoder.decode_beams(items, **decoder_kwargs)[0][0] + beams = self.decoder.decode_beams(items, **decoder_kwargs) + text = beams[0][0] + if return_timestamps: + # Simply cast from pyctcdecode format to wav2vec2 format to leverage + # pre-existing code later + chunk_offset = beams[0][2] + word_offsets = [] + for word, (start_offset, end_offset) in chunk_offset: + word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset}) else: skip_special_tokens = self.type != "ctc" text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens) if return_timestamps: - if return_timestamps == "char": - decoded = self.tokenizer.decode( - items, skip_special_tokens=skip_special_tokens, output_char_offsets=True + char_offsets = self.tokenizer.decode( + items, skip_special_tokens=skip_special_tokens, output_char_offsets=True + )["char_offsets"] + if return_timestamps == "word": + word_offsets = self.tokenizer._get_word_offsets( + char_offsets, self.tokenizer.replace_word_delimiter_char ) - elif return_timestamps == "word": - decoded = self.tokenizer.decode( - items, skip_special_tokens=skip_special_tokens, output_word_offsets=True - ) - chunks = [] - for item in decoded[f"{return_timestamps}_offsets"]: - start = ( - item["start_offset"] - * self.model.config.inputs_to_logits_ratio - / self.feature_extractor.sampling_rate - ) - stop = ( - item["end_offset"] - * self.model.config.inputs_to_logits_ratio - / self.feature_extractor.sampling_rate - ) - chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)}) - optional["chunks"] = chunks + + if return_timestamps: + if return_timestamps == "word": + offsets = word_offsets + else: + offsets = char_offsets + chunks = [] + for item in offsets: + start = item["start_offset"] * self.model.config.inputs_to_logits_ratio + start /= self.feature_extractor.sampling_rate + + stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio + stop /= self.feature_extractor.sampling_rate + + chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)}) + optional["chunks"] = chunks extra = defaultdict(list) for output in model_outputs: diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index e0a0352e86..e3dab51aab 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -188,6 +188,32 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel }, ) + speech_recognizer.type = "ctc_with_lm" + # Simple test with CTC with LM, chunking + timestamps + output = speech_recognizer(filename, chunk_length_s=2.0, return_timestamps="word") + self.assertEqual( + output, + { + "text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajcri", + "chunks": [ + {"text": "y", "timestamp": (0.52, 0.54)}, + {"text": "en", "timestamp": (0.6, 0.68)}, + {"text": "las", "timestamp": (0.74, 0.84)}, + {"text": "ramas", "timestamp": (0.94, 1.24)}, + {"text": "medio", "timestamp": (1.32, 1.52)}, + {"text": "sumergidas", "timestamp": (1.56, 2.22)}, + {"text": "revoloteaban", "timestamp": (2.36, 3.0)}, + {"text": "algunos", "timestamp": (3.06, 3.38)}, + {"text": "pájaros", "timestamp": (3.46, 3.86)}, + {"text": "de", "timestamp": (3.92, 4.0)}, + {"text": "quimérico", "timestamp": (4.08, 4.6)}, + {"text": "y", "timestamp": (4.66, 4.68)}, + {"text": "legendario", "timestamp": (4.74, 5.26)}, + {"text": "plumajcri", "timestamp": (5.34, 5.74)}, + ], + }, + ) + @require_tf def test_small_model_tf(self): self.skipTest("Tensorflow not supported yet.")