From e064f081504ef935a0fef30d5ce7dce4c58bd38b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 1 Mar 2022 17:03:05 +0100 Subject: [PATCH] Add time stamps for wav2vec2 with lm (#15854) * [Wav2Vec2 With LM] add timestamps * correct * correct * Apply suggestions from code review * correct * Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py * make style * Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py Co-authored-by: Nicolas Patry * make style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Nicolas Patry Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../models/wav2vec2/tokenization_wav2vec2.py | 10 +- .../tokenization_wav2vec2_phoneme.py | 7 +- .../processing_wav2vec2_with_lm.py | 106 +++++++++++++++-- .../test_processor_wav2vec2_with_lm.py | 108 +++++++++++++++++- 4 files changed, 215 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 97c6801b75..def404a065 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -97,6 +97,8 @@ WAV2VEC2_KWARGS_DOCSTRING = r""" Whether or not to print more information and warnings. """ +ListOfDict = List[Dict[str, Union[int, str]]] + @dataclass class Wav2Vec2CTCTokenizerOutput(ModelOutput): @@ -106,18 +108,18 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput): Args: text (list of `str` or `str`): Decoded logits in text from. Usually the speech transcription. - char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`): + char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`): Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char offsets can be used to compute time stamps for each charater. Total logit score of the beam associated with produced text. - word_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`): + word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`): Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets can be used to compute time stamps for each word. """ text: Union[List[str], str] - char_offsets: List[Dict[str, Union[float, str]]] = None - word_offsets: List[Dict[str, Union[float, str]]] = None + char_offsets: Union[List[ListOfDict], ListOfDict] = None + word_offsets: Union[List[ListOfDict], ListOfDict] = None class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): diff --git a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py index b705103278..d6f8df3d81 100644 --- a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +++ b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py @@ -66,6 +66,9 @@ PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize} +ListOfDict = List[Dict[str, Union[int, str]]] + + @dataclass class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput): """ @@ -74,14 +77,14 @@ class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput): Args: text (list of `str` or `str`): Decoded logits in text from. Usually the speech transcription. - char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`): + char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`): Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char offsets can be used to compute time stamps for each charater. Total logit score of the beam associated with produced text. """ text: Union[List[str], str] - char_offsets: List[Dict[str, Union[float, str]]] = None + char_offsets: Union[List[ListOfDict], ListOfDict] = None class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer): diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index ca59a948ff..f2828bbf46 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -19,7 +19,7 @@ import os from contextlib import contextmanager from dataclasses import dataclass from multiprocessing import get_context -from typing import TYPE_CHECKING, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union import numpy as np @@ -34,23 +34,30 @@ if TYPE_CHECKING: from ...tokenization_utils import PreTrainedTokenizerBase +ListOfDict = List[Dict[str, Union[int, str]]] + + @dataclass class Wav2Vec2DecoderWithLMOutput(ModelOutput): """ Output type of [`Wav2Vec2DecoderWithLM`], with transcription. Args: - text (list of `str`): + text (list of `str` or `str`): Decoded logits in text from. Usually the speech transcription. - logit_score (list of `float`): + logit_score (list of `float` or `float`): Total logit score of the beam associated with produced text. lm_score (list of `float`): Fused lm_score of the beam associated with produced text. + word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`): + Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets + can be used to compute time stamps for each word. """ text: Union[List[str], str] logit_score: Union[List[float], float] = None lm_score: Union[List[float], float] = None + word_offsets: Union[List[ListOfDict], ListOfDict] = None class Wav2Vec2ProcessorWithLM(ProcessorMixin): @@ -232,6 +239,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): beta: Optional[float] = None, unk_score_offset: Optional[float] = None, lm_score_boundary: Optional[bool] = None, + output_word_offsets: bool = False, ): """ Batch decode output logits to audio transcription with language model support. @@ -267,6 +275,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): Amount of log score offset for unknown tokens lm_score_boundary (`bool`, *optional*): Whether to have kenlm respect boundaries when scoring + output_word_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate + and model downsampling rate to compute the time-stamps of transcribed words. + + + + Please take a look at the Example of [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to + better understand how to make use of `output_word_offsets`. + [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched + output. + + Returns: [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. @@ -310,13 +330,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): pool.close() # extract text and scores - batch_texts, logit_scores, lm_scores = [], [], [] + batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], [] for d in decoded_beams: batch_texts.append(d[0][0]) logit_scores.append(d[0][-2]) lm_scores.append(d[0][-1]) - # more output features will be added in the future - return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores) + word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]]) + + word_offsets = word_offsets if output_word_offsets else None + + return Wav2Vec2DecoderWithLMOutput( + text=batch_texts, logit_score=logit_scores, lm_score=lm_scores, word_offsets=word_offsets + ) def decode( self, @@ -330,6 +355,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): beta: Optional[float] = None, unk_score_offset: Optional[float] = None, lm_score_boundary: Optional[bool] = None, + output_word_offsets: bool = False, ): """ Decode output logits to audio transcription with language model support. @@ -357,11 +383,65 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): Amount of log score offset for unknown tokens lm_score_boundary (`bool`, *optional*): Whether to have kenlm respect boundaries when scoring + output_word_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate + and model downsampling rate to compute the time-stamps of transcribed words. + + + + Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] to + better understand how to make use of `output_word_offsets`. + + Returns: [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. - """ + Example: + + ```python + >>> # Let's see how to retrieve time steps for a model + >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC + >>> from datasets import load_dataset + >>> import datasets + >>> import torch + + >>> # import model, feature extractor, tokenizer + >>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") + >>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") + + >>> # load first sample of English common_voice + >>> dataset = load_dataset("common_voice", "en", split="train", streaming=True) + >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + >>> dataset_iter = iter(dataset) + >>> sample = next(dataset_iter) + + >>> # forward sample through model to get greedily predicted transcription ids + >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values + >>> with torch.no_grad(): + ... logits = model(input_values).logits[0].cpu().numpy() + + >>> # retrieve word stamps (analogous commands for `output_char_offsets`) + >>> outputs = tokenizer.decode(logits, output_word_offsets=True) + >>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate + >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate + + >>> word_offsets = [ + ... { + ... "word": d["word"], + ... "start_time": d["start_offset"] * time_offset, + ... "end_time": d["end_offset"] * time_offset, + ... } + ... for d in outputs.word_offsets + ... ] + >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer: + >>> # https://huggingface.co/datasets/common_voice/viewer/en/train + >>> word_offset + >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES', + >>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A', + >>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ... + ```""" + from pyctcdecode.constants import ( DEFAULT_BEAM_WIDTH, DEFAULT_HOTWORD_WEIGHT, @@ -390,9 +470,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): hotword_weight=hotword_weight, ) + word_offsets = None + if output_word_offsets: + word_offsets = [ + {"word": word, "start_offset": start_offset, "end_offset": end_offset} + for word, (start_offset, end_offset) in decoded_beams[0][2] + ] + # more output features will be added in the future return Wav2Vec2DecoderWithLMOutput( - text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1] + text=decoded_beams[0][0], + logit_score=decoded_beams[0][-2], + lm_score=decoded_beams[0][-1], + word_offsets=word_offsets, ) @contextmanager diff --git a/tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index 7ab65de523..ae1159dc9b 100644 --- a/tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -20,13 +20,15 @@ import unittest from multiprocessing import get_context from pathlib import Path +import datasets import numpy as np +from datasets import load_dataset from transformers import AutoProcessor -from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available +from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES -from transformers.testing_utils import require_pyctcdecode +from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list @@ -35,6 +37,10 @@ if is_pyctcdecode_available(): from huggingface_hub import snapshot_download from pyctcdecode import BeamSearchDecoderCTC from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM + from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput + +if is_torch_available(): + from transformers import Wav2Vec2ForCTC @require_pyctcdecode @@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): decoded_auto = processor_auto.batch_decode(logits) self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text) + + @staticmethod + def get_from_offsets(offsets, key): + retrieved_list = [d[key] for d in offsets] + return retrieved_list + + def test_offsets_integration_fast(self): + processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm") + logits = self._get_dummy_logits()[0] + + outputs = processor.decode(logits, output_word_offsets=True) + # check Wav2Vec2CTCTokenizerOutput keys for word + self.assertTrue(len(outputs.keys()), 2) + self.assertTrue("text" in outputs) + self.assertTrue("word_offsets" in outputs) + self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput)) + + self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["", "", ""]) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 2, 4]) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [1, 3, 5]) + + def test_offsets_integration_fast_batch(self): + processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm") + logits = self._get_dummy_logits() + + outputs = processor.batch_decode(logits, output_word_offsets=True) + + # check Wav2Vec2CTCTokenizerOutput keys for word + self.assertTrue(len(outputs.keys()), 2) + self.assertTrue("text" in outputs) + self.assertTrue("word_offsets" in outputs) + self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput)) + + self.assertListEqual( + [" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text + ) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "word"), ["", "", ""]) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "start_offset"), [0, 2, 4]) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "end_offset"), [1, 3, 5]) + + @slow + @require_torch + @require_torchaudio + def test_word_time_stamp_integration(self): + import torch + + ds = load_dataset("common_voice", "en", split="train", streaming=True) + ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + ds_iter = iter(ds) + sample = next(ds_iter) + + processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") + model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm") + + # compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train + input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values + + with torch.no_grad(): + logits = model(input_values).logits.cpu().numpy() + + output = processor.decode(logits[0], output_word_offsets=True) + + time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate + word_time_stamps = [ + { + "start_time": d["start_offset"] * time_offset, + "end_time": d["end_offset"] * time_offset, + "word": d["word"], + } + for d in output["word_offsets"] + ] + + EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL" + + # output words + self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT) + self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text) + + # output times + start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")] + end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")] + + # fmt: off + self.assertListEqual( + start_times, + [ + 1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66, + ], + ) + + self.assertListEqual( + end_times, + [ + 1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94, + ], + ) + # fmt: on