From c44d3675c285278406722b0fa9eb7afff2a3d434 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 22 Feb 2022 19:26:44 +0100 Subject: [PATCH] Time stamps for CTC models (#15687) * [Wav2Vec2 Time Stamps] * Add first version * add word time stamps * Fix * save intermediate space * improve * [Finish CTC Tokenizer] * remove @ * remove @ * push * continue with phonemes * up * finish PR * up * add example * rename * finish * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * correct split * finalize Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/model_doc/wav2vec2.mdx | 2 + .../models/hubert/configuration_hubert.py | 6 + .../models/sew/configuration_sew.py | 6 + .../models/sew_d/configuration_sew_d.py | 6 + .../unispeech/configuration_unispeech.py | 6 + .../configuration_unispeech_sat.py | 6 + .../models/wav2vec2/configuration_wav2vec2.py | 6 + .../models/wav2vec2/tokenization_wav2vec2.py | 356 +++++++++++++++++- .../tokenization_wav2vec2_phoneme.py | 221 ++++++++++- .../models/wavlm/configuration_wavlm.py | 6 + tests/test_tokenization_wav2vec2.py | 221 +++++++++-- tests/test_tokenization_wav2vec2_phoneme.py | 86 ++++- 12 files changed, 858 insertions(+), 70 deletions(-) diff --git a/docs/source/model_doc/wav2vec2.mdx b/docs/source/model_doc/wav2vec2.mdx index 11d3444cc7..9b2f13ea45 100644 --- a/docs/source/model_doc/wav2vec2.mdx +++ b/docs/source/model_doc/wav2vec2.mdx @@ -45,6 +45,8 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv [[autodoc]] Wav2Vec2CTCTokenizer - __call__ - save_vocabulary + - decode + - batch_decode ## Wav2Vec2FeatureExtractor diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py index 703d0eddfb..df1bbc860f 100644 --- a/src/transformers/models/hubert/configuration_hubert.py +++ 
b/src/transformers/models/hubert/configuration_hubert.py @@ -14,6 +14,8 @@ # limitations under the License. """ Hubert model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -248,3 +250,7 @@ class HubertConfig(PretrainedConfig): # ctc loss self.ctc_loss_reduction = ctc_loss_reduction self.ctc_zero_infinity = ctc_zero_infinity + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/src/transformers/models/sew/configuration_sew.py b/src/transformers/models/sew/configuration_sew.py index 9d0953ffc1..b253f99a42 100644 --- a/src/transformers/models/sew/configuration_sew.py +++ b/src/transformers/models/sew/configuration_sew.py @@ -14,6 +14,8 @@ # limitations under the License. """ SEW model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -243,3 +245,7 @@ class SEWConfig(PretrainedConfig): # sequence classification self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/src/transformers/models/sew_d/configuration_sew_d.py b/src/transformers/models/sew_d/configuration_sew_d.py index d808da2d1a..a0f5fb2e60 100644 --- a/src/transformers/models/sew_d/configuration_sew_d.py +++ b/src/transformers/models/sew_d/configuration_sew_d.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" SEW-D model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -279,3 +281,7 @@ class SEWDConfig(PretrainedConfig): # sequence classification self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/src/transformers/models/unispeech/configuration_unispeech.py b/src/transformers/models/unispeech/configuration_unispeech.py index 996a27fa92..05c42b2457 100644 --- a/src/transformers/models/unispeech/configuration_unispeech.py +++ b/src/transformers/models/unispeech/configuration_unispeech.py @@ -14,6 +14,8 @@ # limitations under the License. """ UniSpeech model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -289,3 +291,7 @@ class UniSpeechConfig(PretrainedConfig): # pretraining loss self.replace_prob = replace_prob + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py index d13260dcff..d76978ea30 100644 --- a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" UniSpeechSat model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -306,3 +308,7 @@ class UniSpeechSatConfig(PretrainedConfig): self.tdnn_kernel = list(tdnn_kernel) self.tdnn_dilation = list(tdnn_dilation) self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 808fab2667..71c81cdc79 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -14,6 +14,8 @@ # limitations under the License. """ Wav2Vec2 model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -329,3 +331,7 @@ class Wav2Vec2Config(PretrainedConfig): self.tdnn_kernel = list(tdnn_kernel) self.tdnn_dilation = list(tdnn_dilation) self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 4b22fbdda1..d070bcb795 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -18,12 +18,22 @@ import json import os import sys import warnings +from dataclasses import dataclass from itertools import groupby -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np -from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings +from ...file_utils import ( + ModelOutput, + PaddingStrategy, + TensorType, + add_end_docstrings, + is_flax_available, + is_tf_available, + is_torch_available, + to_py_obj, +) from 
...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list from ...tokenization_utils_base import AddedToken, BatchEncoding from ...utils import logging @@ -32,6 +42,15 @@ from ...utils import logging logger = logging.get_logger(__name__) +if TYPE_CHECKING: + if is_torch_available(): + import torch + if is_tf_available(): + import tensorflow as tf + if is_flax_available(): + import jax.numpy as jnp # noqa: F401 + + VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", "tokenizer_config_file": "tokenizer_config.json", @@ -79,6 +98,28 @@ WAV2VEC2_KWARGS_DOCSTRING = r""" """ +@dataclass +class Wav2Vec2CTCTokenizerOutput(ModelOutput): + """ + Output type of [`Wav2Vec2CTCTokenizer`], with transcription. + + Args: + text (list of `str` or `str`): + Decoded logits in text form. Usually the speech transcription. + char_offsets (`List[Dict[str, Union[int, str]]]` or `List[List[Dict[str, Union[int, str]]]]`): + Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char + offsets can be used to compute time stamps for each character. Each entry holds the character together + with its start and end offset. + word_offsets (`List[Dict[str, Union[int, str]]]` or `List[List[Dict[str, Union[int, str]]]]`): + Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets + can be used to compute time stamps for each word. 
+ """ + + text: Union[List[str], str] + char_offsets: List[Dict[str, Union[float, str]]] = None + word_offsets: List[Dict[str, Union[float, str]]] = None + + class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): """ @@ -121,6 +162,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): unk_token="", pad_token="", word_delimiter_token="|", + replace_word_delimiter_char=" ", do_lower_case=False, **kwargs ): @@ -131,12 +173,14 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): pad_token=pad_token, do_lower_case=do_lower_case, word_delimiter_token=word_delimiter_token, + replace_word_delimiter_char=replace_word_delimiter_char, **kwargs, ) self._word_delimiter_token = word_delimiter_token self.do_lower_case = do_lower_case + self.replace_word_delimiter_char = replace_word_delimiter_char with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) @@ -204,31 +248,106 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): return result def convert_tokens_to_string( - self, tokens: List[str], group_tokens: bool = True, spaces_between_special_tokens: bool = False - ) -> str: + self, + tokens: List[str], + group_tokens: bool = True, + spaces_between_special_tokens: bool = False, + output_char_offsets: bool = False, + output_word_offsets: bool = False, + ) -> Dict[str, Union[str, float]]: """ Converts a connectionist-temporal-classification (CTC) output tokens into a single string. 
""" # group same tokens into non-repeating tokens in CTC style decoding if group_tokens: - tokens = [token_group[0] for token_group in groupby(tokens)] + chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens))) + else: + chars = tokens + char_repetitions = len(tokens) * [1] # filter self.pad_token which is used as CTC-blank token - filtered_tokens = list(filter(lambda token: token != self.pad_token, tokens)) - - if spaces_between_special_tokens: - join_token = " " - else: - join_token = "" + processed_chars = list(filter(lambda char: char != self.pad_token, chars)) # replace delimiter token - string = join_token.join( - [" " if token == self.word_delimiter_token else token for token in filtered_tokens] - ).strip() + processed_chars = [ + self.replace_word_delimiter_char if char == self.word_delimiter_token else char for char in processed_chars + ] + + # retrieve offsets + char_offsets = word_offsets = None + if output_char_offsets or output_word_offsets: + char_offsets = self._compute_offsets(char_repetitions, chars, self.pad_token) + + if len(char_offsets) != len(processed_chars): + raise ValueError( + f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}" + " have to be of the same length, but are: " + f"`len(offsets)`: {len(char_offsets)} and `len(processed_tokens)`:" + f" {len(processed_chars)}" + ) + + # set tokens to correct processed token + for i, char in enumerate(processed_chars): + char_offsets[i]["char"] = char + + # retrieve word offsets from character offsets + word_offsets = None + if output_word_offsets: + word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char) + + # join to string + join_char = " " if spaces_between_special_tokens else "" + string = join_char.join(processed_chars).strip() if self.do_lower_case: string = string.lower() - return string + + return {"text": string, "char_offsets": char_offsets, "word_offsets": word_offsets} + + 
@staticmethod + def _compute_offsets( + char_repetitions: List[int], chars: List[str], ctc_token: int + ) -> List[Dict[str, Union[str, int]]]: + end_indices = np.asarray(char_repetitions).cumsum() + start_indices = np.concatenate(([0], end_indices[:-1])) + + offsets = [ + {"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices) + ] + + # filter out CTC token + offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets)) + return offsets + + @staticmethod + def _get_word_offsets( + offsets: Dict[str, Union[str, float]], word_delimiter_char: str = " " + ) -> Dict[str, Union[str, float]]: + word_offsets = [] + final_offset_idx = len(offsets) - 1 + + for i, offset in enumerate(offsets): + # define previous, next and current char + char = offset["char"] + prev_char = offsets[i - 1]["char"] if i > 0 else None + next_char = offsets[i + 1]["char"] if i < final_offset_idx else None + + # derive whether word begins, ends and whether current char is in word + word_begin = (i == 0 and char != word_delimiter_char) or (prev_char == word_delimiter_char) + word_end = (i == final_offset_idx and char != word_delimiter_char) or (next_char == word_delimiter_char) + char_is_in_word = char != word_delimiter_char + + if word_begin: + word_offset = {"word": "", "start_offset": offset["start_offset"]} + + if word_end: + word_offset["end_offset"] = offset["end_offset"] + word_offsets.append(word_offset) + + if char_is_in_word: + word_offset["word"] += offset["char"] + + return word_offsets def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): if is_split_into_words: @@ -242,6 +361,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): clean_up_tokenization_spaces: bool = True, group_tokens: bool = True, spaces_between_special_tokens: bool = False, + output_word_offsets: Optional[bool] = False, + output_char_offsets: Optional[bool] = False, ) -> str: """ special _decode function is needed for 
Wav2Vec2Tokenizer because added tokens should be treated exactly the @@ -256,16 +377,210 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): continue result.append(token) - text = self.convert_tokens_to_string( - result, group_tokens=group_tokens, spaces_between_special_tokens=spaces_between_special_tokens + string_output = self.convert_tokens_to_string( + result, + group_tokens=group_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + output_word_offsets=output_word_offsets, + output_char_offsets=output_char_offsets, ) + text = string_output["text"] + if clean_up_tokenization_spaces: - clean_text = self.clean_up_tokenization(text) - return clean_text + text = self.clean_up_tokenization(text) + + if output_word_offsets or output_char_offsets: + return Wav2Vec2CTCTokenizerOutput( + text=text, + char_offsets=string_output["char_offsets"], + word_offsets=string_output["word_offsets"], + ) else: return text + # overwritten from `tokenization_utils_base.py` because tokenizer can output + # `ModelOutput` which should not be a list for batched output and + # because we need docs for `output_char_offsets` here + def batch_decode( + self, + sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + output_char_offsets: bool = False, + output_word_offsets: bool = False, + **kwargs + ) -> List[str]: + """ + Convert a list of lists of token ids into a list of strings by calling decode. + + Args: + sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to clean up the tokenization spaces. 
+ output_char_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output character offsets. Character offsets can be used in combination with the + sampling rate and model downsampling rate to compute the time-stamps of transcribed characters. + + + + Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better + understand how to make use of `output_word_offsets`. + [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output. + + + + output_word_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate + and model downsampling rate to compute the time-stamps of transcribed words. + + + + Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better + understand how to make use of `output_word_offsets`. + [`~model.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output. + + + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `List[str]` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded + sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when + `output_char_offsets == True` or `output_word_offsets == True`. 
+ """ + batch_decoded = [ + self.decode( + seq, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + output_char_offsets=output_char_offsets, + output_word_offsets=output_word_offsets, + **kwargs, + ) + for seq in sequences + ] + if output_char_offsets or output_word_offsets: + # transform list of dicts to dict of lists + return Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]}) + + return batch_decoded + + # overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets` + # and `output_word_offsets` here + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + output_char_offsets: bool = False, + output_word_offsets: bool = False, + **kwargs + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to clean up the tokenization spaces. + output_char_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output character offsets. Character offsets can be used in combination with the + sampling rate and model downsampling rate to compute the time-stamps of transcribed characters. + + + + Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better + understand how to make use of `output_word_offsets`. 
+ + + + output_word_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate + and model downsampling rate to compute the time-stamps of transcribed words. + + + + Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better + understand how to make use of `output_word_offsets`. + + + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded + sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when + `output_char_offsets == True` or `output_word_offsets == True`. + + Example: + + ```python + >>> # Let's see how to retrieve time steps for a model + >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC + >>> from datasets import load_dataset + >>> import datasets + >>> import torch + + >>> # import model, feature extractor, tokenizer + >>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") + + >>> # load first sample of English common_voice + >>> dataset = load_dataset("common_voice", "en", split="train", streaming=True) + >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + >>> dataset_iter = iter(dataset) + >>> sample = next(dataset_iter) + + >>> # forward sample through model to get greedily predicted transcription ids + >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values + >>> logits = model(input_values).logits[0] + >>> pred_ids = torch.argmax(logits, axis=-1) + + >>> # retrieve word stamps (analogous commands for 
`output_char_offsets`) + >>> outputs = tokenizer.decode(pred_ids, output_word_offsets=True) + >>> # compute `time_offset` in seconds as downsampling ratio divided by sampling_rate + >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate + + >>> word_offsets = [ + ... { + ... "word": d["word"], + ... "start_time": d["start_offset"] * time_offset, + ... "end_time": d["end_offset"] * time_offset, + ... } + ... for d in outputs.word_offsets + ... ] + >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer: + >>> # https://huggingface.co/datasets/common_voice/viewer/en/train + >>> word_offsets + >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES', + >>> # 'start_time': 1.64, 'end_time': 1.90}, {'word': 'MILISANDRA', + >>> # 'start_time': 2.26, 'end_time': 2.9}, {'word': 'LOOK', 'start_time': 3.0, 'end_time': 3.16}, ... + ```""" + # Convert inputs to python lists + token_ids = to_py_obj(token_ids) + + return self._decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + output_char_offsets=output_char_offsets, + output_word_offsets=output_word_offsets, + **kwargs, + ) + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") @@ -294,7 +609,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): Returns: `int`: The number of tokens actually added to the vocabulary. 
- Examples: + Example: ```python # Let's see how to increase the vocabulary of Bert model and tokenizer @@ -551,6 +866,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer): if self.do_lower_case: string = string.lower() + return string def _decode( diff --git a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py index 73811f167c..b705103278 100644 --- a/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +++ b/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py @@ -17,10 +17,20 @@ import json import os import sys +from dataclasses import dataclass from itertools import groupby -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from ...file_utils import requires_backends +import numpy as np + +from ...file_utils import ( + ModelOutput, + is_flax_available, + is_tf_available, + is_torch_available, + requires_backends, + to_py_obj, +) from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list from ...tokenization_utils_base import AddedToken from ...utils import logging @@ -29,6 +39,15 @@ from ...utils import logging logger = logging.get_logger(__name__) +if TYPE_CHECKING: + if is_torch_available(): + import torch + if is_tf_available(): + import tensorflow as tf + if is_flax_available(): + import jax.numpy as jnp # noqa: F401 + + VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", "tokenizer_config_file": "tokenizer_config.json", @@ -47,6 +66,24 @@ PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize} +@dataclass +class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput): + """ + Output type of [`Wav2Vec2PhonemeCTCTokenizer`], with transcription. + + Args: + text (list of `str` or `str`): + Decoded logits in text form. Usually the speech transcription. 
+ char_offsets (`List[Dict[str, Union[int, str]]]` or `List[List[Dict[str, Union[int, str]]]]`): + Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char + offsets can be used to compute time stamps for each character. Each entry holds the character together + with its start and end offset. + """ + + text: Union[List[str], str] + char_offsets: List[Dict[str, Union[float, str]]] = None + + class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer): """ @@ -284,24 +321,69 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer): group_tokens: bool = True, spaces_between_special_tokens: bool = False, filter_word_delimiter_token: bool = True, + output_char_offsets: bool = False, ) -> str: """ Converts a connectionist-temporal-classification (CTC) output tokens into a single string. """ # group same tokens into non-repeating tokens in CTC style decoding if group_tokens: - tokens = [token_group[0] for token_group in groupby(tokens)] + chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens))) + else: + chars = tokens + char_repetitions = len(tokens) * [1] # filter self.pad_token which is used as CTC-blank token - filtered_tokens = list(filter(lambda token: token != self.pad_token, tokens)) + processed_chars = list(filter(lambda char: char != self.pad_token, chars)) # also filter self.word_delimiter_token if not not if filter_word_delimiter_token and self.word_delimiter_token is not None: - filtered_tokens = list(filter(lambda token: token != self.word_delimiter_token, filtered_tokens)) + processed_chars = list(filter(lambda token: token != self.word_delimiter_token, processed_chars)) - string = " ".join(filtered_tokens).strip() + # retrieve offsets + char_offsets = None + if output_char_offsets: + word_delimiter_token_for_offsets = ( + self.word_delimiter_token if filter_word_delimiter_token is True else None + ) + char_offsets = self._compute_offsets( + char_repetitions, chars, self.pad_token, 
word_delimiter_token=word_delimiter_token_for_offsets + ) - return string + if len(char_offsets) != len(processed_chars): + raise ValueError( + f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}" + f" have to be of the same length, but are: `len(offsets)`: " + f"{len(char_offsets)} and `len(processed_tokens)`: {len(processed_chars)}" + ) + + # set tokens to correct processed token + for i, char in enumerate(processed_chars): + char_offsets[i]["char"] = char + + string = " ".join(processed_chars).strip() + + return {"text": string, "char_offsets": char_offsets} + + @staticmethod + def _compute_offsets( + char_repetitions: List[int], chars: List[str], ctc_token: int, word_delimiter_token: Optional[int] = None + ) -> List[Dict[str, Union[str, int]]]: + end_indices = np.asarray(char_repetitions).cumsum() + start_indices = np.concatenate(([0], end_indices[:-1])) + + offsets = [ + {"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices) + ] + + # filter out CTC token + offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets)) + + # filter out word delimiter token if necessary + if word_delimiter_token is not None: + offsets = list(filter(lambda offsets: offsets["char"] != word_delimiter_token, offsets)) + + return offsets def _decode( self, @@ -311,6 +393,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer): group_tokens: bool = True, filter_word_delimiter_token: bool = True, spaces_between_special_tokens: bool = False, + output_char_offsets: bool = False, ) -> str: """ special _decode function is needed for Wav2Vec2PhonemeTokenizer because added tokens should be treated exactly @@ -325,19 +408,137 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer): continue result.append(token) - text = self.convert_tokens_to_string( + string_output = self.convert_tokens_to_string( result, group_tokens=group_tokens, spaces_between_special_tokens=spaces_between_special_tokens, 
filter_word_delimiter_token=filter_word_delimiter_token, + output_char_offsets=output_char_offsets, ) + text = string_output["text"] + if clean_up_tokenization_spaces: - clean_text = self.clean_up_tokenization(text) - return clean_text + text = self.clean_up_tokenization(text) + + if output_char_offsets: + return Wav2Vec2PhonemeCTCTokenizerOutput(text=text, char_offsets=string_output["char_offsets"]) else: return text + # overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets` here + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + output_char_offsets: bool = False, + **kwargs + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to clean up the tokenization spaces. + output_char_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output character offsets. Character offsets can be used in combination with the + sampling rate and model downsampling rate to compute the time-stamps of transcribed characters. + + + + Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better + understand how to make use of `output_word_offsets`. + [`~model.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.batch_decode`] works the same way with + phonemes. 
+ + + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str` or [`~models.wav2vec2.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`]: The decoded + sentence. Will be a [`~models.wav2vec2.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`] + when `output_char_offsets == True`. + """ + # Convert inputs to python lists + token_ids = to_py_obj(token_ids) + + return self._decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + output_char_offsets=output_char_offsets, + **kwargs, + ) + + # overwritten from `tokenization_utils_base.py` because tokenizer can output + # `ModelOutput` which should not be a list for batched output and because + # we need docs for `output_char_offsets` here + def batch_decode( + self, + sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + output_char_offsets: bool = False, + **kwargs + ) -> List[str]: + """ + Convert a list of lists of token ids into a list of strings by calling decode. + + Args: + sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to clean up the tokenization spaces. + output_char_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output character offsets. Character offsets can be used in combination with the + sampling rate and model downsampling rate to compute the time-stamps of transcribed characters. 
+ + + + Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better + understand how to make use of `output_word_offsets`. + [`~model.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.batch_decode`] works analogous with phonemes + and batched output. + + + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `List[str]` or [`~models.wav2vec2.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`]: The + decoded sentence. Will be a + [`~models.wav2vec2.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`] when + `output_char_offsets == True`. + """ + batch_decoded = [ + self.decode( + seq, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + output_char_offsets=output_char_offsets, + **kwargs, + ) + for seq in sequences + ] + if output_char_offsets: + # transform list of dicts to dict of lists + return Wav2Vec2PhonemeCTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]}) + + return batch_decoded + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") diff --git a/src/transformers/models/wavlm/configuration_wavlm.py b/src/transformers/models/wavlm/configuration_wavlm.py index 84eb542a16..1c0c1f0d90 100644 --- a/src/transformers/models/wavlm/configuration_wavlm.py +++ b/src/transformers/models/wavlm/configuration_wavlm.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" WavLM model configuration""" +import math + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -330,3 +332,7 @@ class WavLMConfig(PretrainedConfig): self.tdnn_kernel = list(tdnn_kernel) self.tdnn_dilation = list(tdnn_dilation) self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + return math.prod(self.conv_stride) diff --git a/tests/test_tokenization_wav2vec2.py b/tests/test_tokenization_wav2vec2.py index 9cc082dd00..ce278cd8f0 100644 --- a/tests/test_tokenization_wav2vec2.py +++ b/tests/test_tokenization_wav2vec2.py @@ -29,7 +29,7 @@ from transformers import ( Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer, ) -from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES +from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizerOutput from transformers.testing_utils import require_torch, slow from .test_tokenization_common import TokenizerTesterMixin @@ -422,27 +422,16 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): def test_tokenizer_decode_special(self): tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h") + # fmt: off sample_ids = [ [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98], [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77], ] sample_ids_2 = [ [11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98], - [ - 24, - 22, - 5, - tokenizer.pad_token_id, - tokenizer.pad_token_id, - tokenizer.pad_token_id, - tokenizer.word_delimiter_token_id, - 24, - 22, - 5, - 77, - tokenizer.word_delimiter_token_id, - ], + [24, 22, 5, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.word_delimiter_token_id], ] + # fmt: on batch_tokens = tokenizer.batch_decode(sample_ids) batch_tokens_2 = tokenizer.batch_decode(sample_ids_2) @@ -454,27 +443,12 @@ class 
Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer.add_tokens(["!", "?"]) tokenizer.add_special_tokens({"cls_token": "$$$"}) + # fmt: off sample_ids = [ - [ - 11, - 5, - 15, - tokenizer.pad_token_id, - 15, - 8, - 98, - 32, - 32, - 33, - tokenizer.word_delimiter_token_id, - 32, - 32, - 33, - 34, - 34, - ], + [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34], [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34], ] + # fmt: on batch_tokens = tokenizer.batch_decode(sample_ids) self.assertEqual(batch_tokens, ["HELLO!?!?$$$", "BYE BYE$$$"]) @@ -499,6 +473,187 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True) self.assertEqual(sent, expected_sent) + @staticmethod + def get_from_offsets(offsets, key): + retrieved_list = [d[key] for d in offsets] + return retrieved_list + + def test_offsets(self): + tokenizer = self.get_tokenizer() + + # fmt: off + # HEEEEE||LLLLO => HE LLO + # 1H + 5E + 2| + 3L + 1 + 1L + 1O + 1 + sample_ids = [11, 5, 5, 5, 5, 5, 4, 4, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98] + # fmt: on + + outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True) + # check Wav2Vec2CTCTokenizerOutput keys for char + self.assertTrue(len(outputs_char.keys()), 2) + self.assertTrue("text" in outputs_char) + self.assertTrue("char_offsets" in outputs_char) + self.assertTrue(isinstance(outputs_char, Wav2Vec2CTCTokenizerOutput)) + + outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True) + # check Wav2Vec2CTCTokenizerOutput keys for word + self.assertTrue(len(outputs_word.keys()), 2) + self.assertTrue("text" in outputs_word) + self.assertTrue("word_offsets" in outputs_word) + self.assertTrue(isinstance(outputs_word, Wav2Vec2CTCTokenizerOutput)) + + outputs = tokenizer.decode(sample_ids, 
output_char_offsets=True, output_word_offsets=True) + # check Wav2Vec2CTCTokenizerOutput keys for both + self.assertTrue(len(outputs.keys()), 3) + self.assertTrue("text" in outputs) + self.assertTrue("char_offsets" in outputs) + self.assertTrue("word_offsets" in outputs) + self.assertTrue(isinstance(outputs, Wav2Vec2CTCTokenizerOutput)) + + # check that order of chars is correct and identical for both outputs + self.assertEqual("".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text) + self.assertEqual( + self.get_from_offsets(outputs["char_offsets"], "char"), ["H", "E", " ", "L", "L", "O", ""] + ) + self.assertListEqual( + self.get_from_offsets(outputs["char_offsets"], "char"), + self.get_from_offsets(outputs_char["char_offsets"], "char"), + ) + + # check that order of words is correct and identical to both outputs + self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text) + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["HE", "LLO"]) + self.assertListEqual( + self.get_from_offsets(outputs["word_offsets"], "word"), + self.get_from_offsets(outputs_word["word_offsets"], "word"), + ) + + # check that offsets are actually correct for char + # 0 is H, 1 is E, 6 is | (" "), 8 is 1st L, 12 is 2nd L, 13 is O, 14 is + self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 6, 8, 12, 13, 14]) + # 1 is H, 6 is E, 8 is | (" "), 11 is 1st L (note due to + # different begin of 2nd L), 13 is 2nd L, 14 is O, 15 is + self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 6, 8, 11, 13, 14, 15]) + + # check that offsets are actually correct for word + # H is at 1st position of first word, first L is at 8th position of second word + self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 8]) + # last E is at 6th position of first word, first L is at last (15th) position of second word + 
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [6, 15]) + + def test_offsets_batch(self): + tokenizer = self.get_tokenizer() + + def check_list_tuples_equal(outputs_batch, outputs_list): + self.assertTrue(isinstance(outputs_batch, Wav2Vec2CTCTokenizerOutput)) + self.assertTrue(isinstance(outputs_list[0], Wav2Vec2CTCTokenizerOutput)) + + # transform list to ModelOutput + outputs_batch_2 = Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in outputs_list] for k in outputs_list[0]}) + + self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"]) + + def recursive_check(list_or_dict_1, list_or_dict_2): + if isinstance(list_or_dict_1, list): + [recursive_check(l1, l2) for l1, l2 in zip(list_or_dict_1, list_or_dict_2)] + self.assertEqual(list_or_dict_1, list_or_dict_2) + + if "char_offsets" in outputs_batch: + recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"]) + + if "word_offsets" in outputs_batch: + recursive_check(outputs_batch["word_offsets"], outputs_batch_2["word_offsets"]) + + # fmt: off + sample_ids = [ + [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34], + [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34], + ] + # fmt: on + + # We assume that `decode` works as expected. 
All we will check now is + # the output type is correct and the output is identical to `decode` + + # char + outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True) + outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids] + check_list_tuples_equal(outputs_char_batch, outputs_char) + + # word + outputs_word_batch = tokenizer.batch_decode(sample_ids, output_word_offsets=True) + outputs_word = [tokenizer.decode(ids, output_word_offsets=True) for ids in sample_ids] + check_list_tuples_equal(outputs_word_batch, outputs_word) + + # both + outputs_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True, output_word_offsets=True) + outputs = [tokenizer.decode(ids, output_word_offsets=True, output_char_offsets=True) for ids in sample_ids] + check_list_tuples_equal(outputs_batch, outputs) + + def test_offsets_integration(self): + tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h") + # pred_ids correspond to the following code + # ``` + # from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC + # from datasets import load_dataset + # import datasets + # import torch + # model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h") + # feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") + # + # ds = load_dataset("common_voice", "en", split="train", streaming=True) + # ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + # ds_iter = iter(ds) + # sample = next(ds_iter) + # + # input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values + # logits = model(input_values).logits + # pred_ids = torch.argmax(logits, axis=-1).cpu().tolist() + # ``` + # fmt: off + pred_ids = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 18, 11, 0, 0, 0, 22, 0, 0, 4, 4, 4, 14, 0, 0, 0, 0, 0, 8, 8, 0, 5, 5, 0, 12, 0, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 10, 0, 0, 0, 15, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 7, 0, 9, 0, 0, 14, 0, 0, 0, 13, 0, 7, 0, 0, 4, 4, 0, 15, 8, 8, 0, 0, 8, 0, 26, 0, 0, 4, 4, 0, 0, 15, 0, 0, 0, 0, 0, 0, 10, 0, 26, 5, 5, 0, 4, 4, 0, 0, 12, 11, 0, 0, 5, 4, 4, 4, 0, 18, 0, 0, 0, 7, 9, 9, 0, 6, 0, 12, 12, 4, 4, 0, 6, 0, 0, 8, 0, 4, 4, 4, 0, 19, 0, 0, 8, 9, 9, 0, 0, 0, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 0, 17, 5, 5, 5, 0, 4, 4, 4, 0, 0, 29, 29, 0, 0, 0, 0, 8, 11, 0, 9, 9, 0, 0, 0, 4, 4, 0, 12, 12, 0, 0, 0, 9, 0, 0, 0, 0, 0, 8, 18, 0, 0, 0, 4, 4, 0, 0, 8, 9, 0, 4, 4, 0, 6, 11, 5, 0, 4, 4, 0, 13, 13, 0, 0, 0, 10, 0, 0, 25, 0, 0, 6, 0, 4, 4, 0, 0, 0, 0, 7, 0, 0, 23, 0, 0, 4, 4, 0, 0, 0, 6, 11, 0, 5, 4, 4, 18, 0, 0, 0, 0, 0, 0, 7, 15, 0, 0, 0, 15, 15, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + + # wav2vec2-base downsamples input audio by a factor of 320 + # sampling rate for wav2vec2-base is 16_000 + time_offset_wav2vec2_base = 320 / 16_000 + + expected_char_time_stamps_text = ['W', 'H', 'Y', ' ', 'D', 'O', 'E', 'S', ' ', 'M', 'I', 'L', 'I', 'S', 'A', 'N', 'D', 'R', 'A', ' ', 'L', 'O', 'O', 'K', ' ', 'L', 'I', 'K', 'E', ' ', 'S', 'H', 'E', ' ', 'W', 'A', 'N', 'T', 'S', ' ', 'T', 'O', ' ', 'C', 'O', 'N', 'S', 'U', 'M', 'E', ' ', 'J', 'O', 'H', 'N', ' ', 'S', 'N', 'O', 'W', ' ', 'O', 'N', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'T', ' ', 'A', 'P', ' ', 'T', 'H', 'E', ' ', 'W', 'A', 'L', 'L', ' '] + expected_char_time_stamps_start = [1.42, 1.44, 1.52, 1.58, 1.64, 1.76, 1.82, 1.88, 1.92, 2.26, 2.32, 2.4, 2.46, 2.54, 2.66, 2.7, 2.76, 2.84, 2.88, 2.94, 3.0, 3.02, 3.1, 3.14, 3.2, 3.28, 3.42, 3.46, 3.48, 3.54, 3.62, 3.64, 3.7, 3.72, 3.8, 3.88, 3.9, 3.96, 4.0, 4.04, 4.1, 4.16, 4.2, 4.28, 4.34, 4.36, 4.48, 4.66, 4.74, 4.76, 4.84, 4.94, 5.06, 5.08, 5.12, 5.22, 5.28, 5.38, 5.5, 5.52, 5.6, 
5.68, 5.7, 5.74, 5.8, 5.82, 5.84, 5.88, 5.94, 6.04, 6.1, 6.16, 6.2, 6.32, 6.38, 6.44, 6.54, 6.56, 6.6, 6.62, 6.66, 6.8, 6.82, 6.9, 6.96] + expected_char_time_stamps_end = [1.44, 1.46, 1.54, 1.64, 1.66, 1.8, 1.86, 1.9, 2.06, 2.28, 2.34, 2.42, 2.48, 2.56, 2.68, 2.72, 2.78, 2.86, 2.9, 2.98, 3.02, 3.06, 3.12, 3.16, 3.24, 3.3, 3.44, 3.48, 3.52, 3.58, 3.64, 3.66, 3.72, 3.78, 3.82, 3.9, 3.94, 3.98, 4.04, 4.08, 4.12, 4.18, 4.26, 4.3, 4.36, 4.4, 4.52, 4.7, 4.76, 4.82, 4.9, 4.98, 5.08, 5.1, 5.16, 5.26, 5.32, 5.4, 5.52, 5.54, 5.64, 5.7, 5.72, 5.78, 5.82, 5.84, 5.86, 5.92, 5.98, 6.06, 6.12, 6.18, 6.24, 6.34, 6.4, 6.48, 6.56, 6.58, 6.62, 6.66, 6.68, 6.82, 6.84, 6.94, 7.02] + + expected_word_time_stamps_text = ['WHY', 'DOES', 'MILISANDRA', 'LOOK', 'LIKE', 'SHE', 'WANTS', 'TO', 'CONSUME', 'JOHN', 'SNOW', 'ON', 'THE', 'RIVT', 'AP', 'THE', 'WALL'] + expected_word_time_stamps_start = [1.42, 1.64, 2.26, 3.0, 3.28, 3.62, 3.8, 4.1, 4.28, 4.94, 5.28, 5.68, 5.8, 5.94, 6.32, 6.54, 6.66] + expected_word_time_stamps_end = [1.54, 1.9, 2.9, 3.16, 3.52, 3.72, 4.04, 4.18, 4.82, 5.16, 5.54, 5.72, 5.86, 6.18, 6.4, 6.62, 6.94] + # fmt: on + + output = tokenizer.batch_decode(pred_ids, output_char_offsets=True, output_word_offsets=True) + + char_offsets_text = self.get_from_offsets(output["char_offsets"][0], "char") + char_offsets_start = self.get_from_offsets(output["char_offsets"][0], "start_offset") + char_offsets_end = self.get_from_offsets(output["char_offsets"][0], "end_offset") + + word_offsets_text = self.get_from_offsets(output["word_offsets"][0], "word") + word_offsets_start = self.get_from_offsets(output["word_offsets"][0], "start_offset") + word_offsets_end = self.get_from_offsets(output["word_offsets"][0], "end_offset") + + # let's transform offsets to time stamps in seconds + char_time_stamps_start = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_start] + char_time_stamps_end = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_end] + + 
word_time_stamps_start = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_start] + word_time_stamps_end = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_end] + + # NOTE: you can verify the above results by checking out the dataset viewer + # on https://huggingface.co/datasets/common_voice/viewer/en/train and + # downloading / playing the sample `common_voice_en_100038.mp3`. As + # you can hear the time-stamps match more or less + + self.assertListEqual(expected_char_time_stamps_text, char_offsets_text) + self.assertListEqual(expected_char_time_stamps_start, char_time_stamps_start) + self.assertListEqual(expected_char_time_stamps_end, char_time_stamps_end) + + self.assertListEqual(expected_word_time_stamps_text, word_offsets_text) + self.assertListEqual(expected_word_time_stamps_start, word_time_stamps_start) + self.assertListEqual(expected_word_time_stamps_end, word_time_stamps_end) + def test_pretrained_model_lists(self): # Wav2Vec2Model has no max model length => no testing pass diff --git a/tests/test_tokenization_wav2vec2_phoneme.py b/tests/test_tokenization_wav2vec2_phoneme.py index 7c479d0733..e72ffce1f6 100644 --- a/tests/test_tokenization_wav2vec2_phoneme.py +++ b/tests/test_tokenization_wav2vec2_phoneme.py @@ -20,6 +20,7 @@ from typing import Tuple from transformers import Wav2Vec2PhonemeCTCTokenizer from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES +from transformers.models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizerOutput from transformers.testing_utils import require_phonemizer from .test_tokenization_common import TokenizerTesterMixin @@ -248,23 +249,94 @@ class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): batch_tokens = tokenizer.batch_decode(sample_ids) self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? 
$$$", "j ð s j ð s oːɹ $$$"]) - # overwrite common test + @staticmethod + def get_from_offsets(offsets, key): + retrieved_list = [d[key] for d in offsets] + return retrieved_list + + def test_offsets(self): + tokenizer = self.get_tokenizer(word_delimiter_token="|") + tokenizer.add_tokens("|") + + # fmt: off + # ksssɾɾ|ɾɾɾɾ|ɾlll|ɭʲ -> k s ɾ ɾ | ɾ l | ɭʲ" + sample_ids = [11, 5, 5, 5, 15, 15, tokenizer.pad_token_id, 15, 15, tokenizer.word_delimiter_token_id, tokenizer.pad_token_id, 15, 8, 8, 8, tokenizer.word_delimiter_token_id, 98] + # fmt: on + + outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False) + # check Wav2Vec2CTCTokenizerOutput keys for char + self.assertTrue(len(outputs.keys()), 2) + self.assertTrue("text" in outputs) + self.assertTrue("char_offsets" in outputs) + self.assertTrue(isinstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput)) + + # check that order of chars is correct and identical for both outputs + self.assertEqual(" ".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text) + self.assertListEqual( + self.get_from_offsets(outputs["char_offsets"], "char"), ["k", "s", "ɾ", "ɾ", "|", "ɾ", "l", "|", "ɭʲ"] + ) + + # check that offsets are actually correct for char + # 0-1 is 11, 1-4 is 5, 4-6 is first 15, 6-7 is (thus not shown), 7-9 is second 15, 9-10 is word_delimiter_token, + # 10-11 is (thus not shown), 11-12 is third 15, 12-15 is 8, 15-16 is word_delimiter_token, 16-17 is 98 + self.assertListEqual( + self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 4, 7, 9, 11, 12, 15, 16] + ) + self.assertListEqual( + self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 4, 6, 9, 10, 12, 15, 16, 17] + ) + + def test_offsets_batch(self): + tokenizer = self.get_tokenizer(word_delimiter_token="|") + + def check_list_tuples_equal(outputs_batch, outputs_list): + self.assertTrue(isinstance(outputs_batch, Wav2Vec2PhonemeCTCTokenizerOutput)) + 
self.assertTrue(isinstance(outputs_list[0], Wav2Vec2PhonemeCTCTokenizerOutput)) + + # transform list to ModelOutput + outputs_batch_2 = Wav2Vec2PhonemeCTCTokenizerOutput( + {k: [d[k] for d in outputs_list] for k in outputs_list[0]} + ) + + self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"]) + + def recursive_check(list_or_dict_1, list_or_dict_2): + if isinstance(list_or_dict_1, list): + [recursive_check(l1, l2) for l1, l2 in zip(list_or_dict_1, list_or_dict_2)] + self.assertEqual(list_or_dict_1, list_or_dict_2) + + if "char_offsets" in outputs_batch: + recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"]) + + # fmt: off + sample_ids = [ + [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34], + [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34], + ] + # fmt: on + + # We assume that `decode` works as expected. 
All we will check now is + # the output type is correct and the output is identical to `decode` + + # char + outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True) + outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids] + check_list_tuples_equal(outputs_char_batch, outputs_char) + + @unittest.skip("Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes") def test_added_tokens_do_lower_case(self): - # Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes pass - # overwrite common test + @unittest.skip("Wav2Vec2PhonemeTokenizer always puts spaces between phonemes") def test_encode_decode_with_spaces(self): - # Wav2Vec2PhonemeTokenizer always puts spaces between phonemes pass - # overwrite common test + @unittest.skip("encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency") def test_internal_consistency(self): - # encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency pass + @unittest.skip("Wav2Vec2PhonemeModel has no max model length => no testing") def test_pretrained_model_lists(self): - # Wav2Vec2PhonemeModel has no max model length => no testing pass # overwrite common