Time stamps for CTC models (#15687)
* [Wav2Vec2 Time Stamps] * Add first version * add word time stamps * Fix * save intermediate space * improve * [Finish CTC Tokenizer] * remove @ * remove @ * push * continue with phonemes * up * finish PR * up * add example * rename * finish * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * correct split * finalize Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
32295b15a1
commit
c44d3675c2
@@ -45,6 +45,8 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
||||
[[autodoc]] Wav2Vec2CTCTokenizer
|
||||
- __call__
|
||||
- save_vocabulary
|
||||
- decode
|
||||
- batch_decode
|
||||
|
||||
## Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" Hubert model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -248,3 +250,7 @@ class HubertConfig(PretrainedConfig):
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    Presumably the factor by which the raw input length is reduced before logits are produced (confirm against
    the feature encoder). An empty stride sequence yields `1`.
    """
    # `math.prod` was only added in Python 3.8; `functools.reduce` gives the same result on every Python 3
    # release. The imports are local so that this property does not depend on module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" SEW model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -243,3 +245,7 @@ class SEWConfig(PretrainedConfig):
|
||||
# sequence classification
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    Presumably the factor by which the raw input length is reduced before logits are produced (confirm against
    the feature encoder). An empty stride sequence yields `1`.
    """
    # `math.prod` was only added in Python 3.8; `functools.reduce` gives the same result on every Python 3
    # release. The imports are local so that this property does not depend on module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" SEW-D model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -279,3 +281,7 @@ class SEWDConfig(PretrainedConfig):
|
||||
# sequence classification
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    Presumably the factor by which the raw input length is reduced before logits are produced (confirm against
    the feature encoder). An empty stride sequence yields `1`.
    """
    # `math.prod` was only added in Python 3.8; `functools.reduce` gives the same result on every Python 3
    # release. The imports are local so that this property does not depend on module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" UniSpeech model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -289,3 +291,7 @@ class UniSpeechConfig(PretrainedConfig):
|
||||
|
||||
# pretraining loss
|
||||
self.replace_prob = replace_prob
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    Presumably the factor by which the raw input length is reduced before logits are produced (confirm against
    the feature encoder). An empty stride sequence yields `1`.
    """
    # `math.prod` was only added in Python 3.8; `functools.reduce` gives the same result on every Python 3
    # release. The imports are local so that this property does not depend on module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" UniSpeechSat model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -306,3 +308,7 @@ class UniSpeechSatConfig(PretrainedConfig):
|
||||
self.tdnn_kernel = list(tdnn_kernel)
|
||||
self.tdnn_dilation = list(tdnn_dilation)
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    Presumably the factor by which the raw input length is reduced before logits are produced (confirm against
    the feature encoder). An empty stride sequence yields `1`.
    """
    # `math.prod` was only added in Python 3.8; `functools.reduce` gives the same result on every Python 3
    # release. The imports are local so that this property does not depend on module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" Wav2Vec2 model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -329,3 +331,7 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.tdnn_kernel = list(tdnn_kernel)
|
||||
self.tdnn_dilation = list(tdnn_dilation)
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    Presumably the factor by which the raw input length is reduced before logits are produced (confirm against
    the feature encoder). An empty stride sequence yields `1`.
    """
    # `math.prod` was only added in Python 3.8; `functools.reduce` gives the same result on every Python 3
    # release. The imports are local so that this property does not depend on module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -18,12 +18,22 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from itertools import groupby
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
PaddingStrategy,
|
||||
TensorType,
|
||||
add_end_docstrings,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
to_py_obj,
|
||||
)
|
||||
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||
from ...utils import logging
|
||||
@@ -32,6 +42,15 @@ from ...utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp # noqa: F401
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"tokenizer_config_file": "tokenizer_config.json",
|
||||
@@ -79,6 +98,28 @@ WAV2VEC2_KWARGS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class Wav2Vec2CTCTokenizerOutput(ModelOutput):
    """
    Output type of [`Wav2Vec2CTCTokenizer`], with transcription.

    Args:
        text (list of `str` or `str`):
            Decoded logits in text form. Usually the speech transcription.
        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
            Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
            offsets can be used to compute time stamps for each character. One list per sample when produced by
            batched decoding.
        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
            Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
            can be used to compute time stamps for each word. One list per sample when produced by batched
            decoding.
    """

    text: Union[List[str], str]
    # both offset fields stay `None` unless the corresponding offsets were requested
    char_offsets: Optional[Union[List[Dict[str, Union[float, str]]], List[List[Dict[str, Union[float, str]]]]]] = None
    word_offsets: Optional[Union[List[Dict[str, Union[float, str]]], List[List[Dict[str, Union[float, str]]]]]] = None
|
||||
|
||||
|
||||
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
"""
|
||||
@@ -121,6 +162,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
word_delimiter_token="|",
|
||||
replace_word_delimiter_char=" ",
|
||||
do_lower_case=False,
|
||||
**kwargs
|
||||
):
|
||||
@@ -131,12 +173,14 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
pad_token=pad_token,
|
||||
do_lower_case=do_lower_case,
|
||||
word_delimiter_token=word_delimiter_token,
|
||||
replace_word_delimiter_char=replace_word_delimiter_char,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._word_delimiter_token = word_delimiter_token
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
self.replace_word_delimiter_char = replace_word_delimiter_char
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
@@ -204,31 +248,106 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
return result
|
||||
|
||||
def convert_tokens_to_string(
    self,
    tokens: List[str],
    group_tokens: bool = True,
    spaces_between_special_tokens: bool = False,
    output_char_offsets: bool = False,
    output_word_offsets: bool = False,
) -> Dict[str, Union[str, float]]:
    """
    Converts a connectionist-temporal-classification (CTC) output tokens into a single string.

    Args:
        tokens (`List[str]`):
            Token sequence to convert, still containing repeated tokens and `self.pad_token` (used as CTC blank).
        group_tokens (`bool`, *optional*, defaults to `True`):
            Whether to collapse runs of identical consecutive tokens into one character.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether to join the resulting characters with a space instead of the empty string.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether to compute character offsets.
        output_word_offsets (`bool`, *optional*, defaults to `False`):
            Whether to compute word offsets. Character offsets are computed as well in that case because word
            offsets are derived from them.

    Returns:
        `dict` with the keys `"text"`, `"char_offsets"` and `"word_offsets"`. Offsets that were not computed are
        `None`.

    Raises:
        `ValueError`: If the number of computed character offsets differs from the number of kept characters.
    """
    need_char_offsets = output_char_offsets or output_word_offsets

    if len(tokens) == 0:
        # `zip(*...)` over an empty `groupby` yields nothing and cannot be unpacked into two names, so an
        # empty input is answered directly with an empty transcription.
        return {
            "text": "",
            "char_offsets": [] if need_char_offsets else None,
            "word_offsets": [] if output_word_offsets else None,
        }

    # group same tokens into non-repeating tokens in CTC style decoding
    if group_tokens:
        chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
    else:
        chars = tokens
        char_repetitions = len(tokens) * [1]

    # filter self.pad_token which is used as CTC-blank token
    processed_chars = list(filter(lambda char: char != self.pad_token, chars))

    # replace delimiter token
    processed_chars = [
        self.replace_word_delimiter_char if char == self.word_delimiter_token else char for char in processed_chars
    ]

    # retrieve offsets
    char_offsets = word_offsets = None
    if need_char_offsets:
        char_offsets = self._compute_offsets(char_repetitions, chars, self.pad_token)

        if len(char_offsets) != len(processed_chars):
            raise ValueError(
                f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
                " have to be of the same length, but are: "
                f"`len(offsets)`: {len(char_offsets)} and `len(processed_tokens)`:"
                f" {len(processed_chars)}"
            )

        # set tokens to correct processed token
        for i, char in enumerate(processed_chars):
            char_offsets[i]["char"] = char

        # retrieve word offsets from character offsets
        word_offsets = None
        if output_word_offsets:
            word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)

    # join to string
    join_char = " " if spaces_between_special_tokens else ""
    string = join_char.join(processed_chars).strip()

    if self.do_lower_case:
        string = string.lower()

    return {"text": string, "char_offsets": char_offsets, "word_offsets": word_offsets}
|
||||
|
||||
@staticmethod
|
||||
def _compute_offsets(
|
||||
char_repetitions: List[int], chars: List[str], ctc_token: int
|
||||
) -> List[Dict[str, Union[str, int]]]:
|
||||
end_indices = np.asarray(char_repetitions).cumsum()
|
||||
start_indices = np.concatenate(([0], end_indices[:-1]))
|
||||
|
||||
offsets = [
|
||||
{"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices)
|
||||
]
|
||||
|
||||
# filter out CTC token
|
||||
offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets))
|
||||
return offsets
|
||||
|
||||
@staticmethod
|
||||
def _get_word_offsets(
|
||||
offsets: Dict[str, Union[str, float]], word_delimiter_char: str = " "
|
||||
) -> Dict[str, Union[str, float]]:
|
||||
word_offsets = []
|
||||
final_offset_idx = len(offsets) - 1
|
||||
|
||||
for i, offset in enumerate(offsets):
|
||||
# define previous, next and current char
|
||||
char = offset["char"]
|
||||
prev_char = offsets[i - 1]["char"] if i > 0 else None
|
||||
next_char = offsets[i + 1]["char"] if i < final_offset_idx else None
|
||||
|
||||
# derive whether word begins, ends and whether current char is in word
|
||||
word_begin = (i == 0 and char != word_delimiter_char) or (prev_char == word_delimiter_char)
|
||||
word_end = (i == final_offset_idx and char != word_delimiter_char) or (next_char == word_delimiter_char)
|
||||
char_is_in_word = char != word_delimiter_char
|
||||
|
||||
if word_begin:
|
||||
word_offset = {"word": "", "start_offset": offset["start_offset"]}
|
||||
|
||||
if word_end:
|
||||
word_offset["end_offset"] = offset["end_offset"]
|
||||
word_offsets.append(word_offset)
|
||||
|
||||
if char_is_in_word:
|
||||
word_offset["word"] += offset["char"]
|
||||
|
||||
return word_offsets
|
||||
|
||||
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
||||
if is_split_into_words:
|
||||
@@ -242,6 +361,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
clean_up_tokenization_spaces: bool = True,
|
||||
group_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = False,
|
||||
output_word_offsets: Optional[bool] = False,
|
||||
output_char_offsets: Optional[bool] = False,
|
||||
) -> str:
|
||||
"""
|
||||
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
||||
@@ -256,16 +377,210 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
continue
|
||||
result.append(token)
|
||||
|
||||
text = self.convert_tokens_to_string(
|
||||
result, group_tokens=group_tokens, spaces_between_special_tokens=spaces_between_special_tokens
|
||||
string_output = self.convert_tokens_to_string(
|
||||
result,
|
||||
group_tokens=group_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
output_word_offsets=output_word_offsets,
|
||||
output_char_offsets=output_char_offsets,
|
||||
)
|
||||
|
||||
text = string_output["text"]
|
||||
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = self.clean_up_tokenization(text)
|
||||
return clean_text
|
||||
text = self.clean_up_tokenization(text)
|
||||
|
||||
if output_word_offsets or output_char_offsets:
|
||||
return Wav2Vec2CTCTokenizerOutput(
|
||||
text=text,
|
||||
char_offsets=string_output["char_offsets"],
|
||||
word_offsets=string_output["word_offsets"],
|
||||
)
|
||||
else:
|
||||
return text
|
||||
|
||||
# overwritten from `tokenization_utils_base.py` because tokenizer can output
|
||||
# `ModelOutput` which should not be a list for batched output and
|
||||
# because we need docs for `output_char_offsets` here
|
||||
def batch_decode(
    self,
    sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    output_word_offsets: bool = False,
    **kwargs
) -> List[str]:
    """
    Convert a list of lists of token ids into a list of strings by calling decode.

    Args:
        sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to
            make use of offsets. `batch_decode` works the same way with batched output.

            </Tip>

        output_word_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
            and model downsampling rate to compute the time-stamps of transcribed words.

            <Tip>

            Please take a look at the example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to
            make use of `output_word_offsets`. `batch_decode` works the same way with batched output.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `List[str]` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded
        sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
        `output_char_offsets == True` or `output_word_offsets == True`.
    """
    decoded_samples = []
    for sequence in sequences:
        decoded_samples.append(
            self.decode(
                sequence,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                output_char_offsets=output_char_offsets,
                output_word_offsets=output_word_offsets,
                **kwargs,
            )
        )

    # without offsets every sample is a plain string and the list is returned as is
    if not (output_char_offsets or output_word_offsets):
        return decoded_samples

    # transform list of dicts to dict of lists; the keys are taken from the first sample
    first_sample = decoded_samples[0]
    merged = {key: [sample[key] for sample in decoded_samples] for key in first_sample}
    return Wav2Vec2CTCTokenizerOutput(merged)
|
||||
|
||||
# overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets`
|
||||
# and `output_word_offsets` here
|
||||
def decode(
    self,
    token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    output_word_offsets: bool = False,
    **kwargs
) -> str:
    """
    Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
    tokens and clean up tokenization spaces.

    Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            The example below shows the workflow for `output_word_offsets`; `output_char_offsets` is used
            analogously.

            </Tip>

        output_word_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
            and model downsampling rate to compute the time-stamps of transcribed words.

            <Tip>

            Please take a look at the example below to better understand how to make use of
            `output_word_offsets`.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `str` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The decoded sentence. Will
        be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
        `output_char_offsets == True` or `output_word_offsets == True`.

    Example:

    ```python
    >>> # Let's see how to retrieve time steps for a model
    >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
    >>> from datasets import load_dataset
    >>> import datasets
    >>> import torch

    >>> # import model, feature extractor, tokenizer
    >>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    >>> # load first sample of English common_voice
    >>> dataset = load_dataset("common_voice", "en", split="train", streaming=True)
    >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
    >>> dataset_iter = iter(dataset)
    >>> sample = next(dataset_iter)

    >>> # forward sample through model to get greedily predicted transcription ids
    >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
    >>> logits = model(input_values).logits[0]
    >>> pred_ids = torch.argmax(logits, axis=-1)

    >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
    >>> outputs = tokenizer.decode(pred_ids, output_word_offsets=True)
    >>> # compute `time_offset` in seconds as the downsampling ratio divided by the sampling rate
    >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate

    >>> word_offsets = [
    ...     {
    ...         "word": d["word"],
    ...         "start_time": d["start_offset"] * time_offset,
    ...         "end_time": d["end_offset"] * time_offset,
    ...     }
    ...     for d in outputs.word_offsets
    ... ]
    >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
    >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
    >>> word_offsets
    >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
    >>> # 'start_time': 1.64, 'end_time': 1.90}, {'word': 'MILISANDRA',
    >>> # 'start_time': 2.26, 'end_time': 2.9}, {'word': 'LOOK', 'start_time': 3.0, 'end_time': 3.16}, ...
    ```"""
    # Convert inputs to python lists
    token_ids = to_py_obj(token_ids)

    return self._decode(
        token_ids=token_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        output_char_offsets=output_char_offsets,
        output_word_offsets=output_word_offsets,
        **kwargs,
    )
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
@@ -294,7 +609,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
Returns:
|
||||
`int`: The number of tokens actually added to the vocabulary.
|
||||
|
||||
Examples:
|
||||
Example:
|
||||
|
||||
```python
|
||||
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
||||
@@ -551,6 +866,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
|
||||
if self.do_lower_case:
|
||||
string = string.lower()
|
||||
|
||||
return string
|
||||
|
||||
def _decode(
|
||||
|
||||
@@ -17,10 +17,20 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from itertools import groupby
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ...file_utils import requires_backends
|
||||
import numpy as np
|
||||
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
requires_backends,
|
||||
to_py_obj,
|
||||
)
|
||||
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
||||
from ...tokenization_utils_base import AddedToken
|
||||
from ...utils import logging
|
||||
@@ -29,6 +39,15 @@ from ...utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp # noqa: F401
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"tokenizer_config_file": "tokenizer_config.json",
|
||||
@@ -47,6 +66,24 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
|
||||
|
||||
|
||||
@dataclass
class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
    """
    Output type of [`Wav2Vec2PhonemeCTCTokenizer`], with transcription.

    Args:
        text (list of `str` or `str`):
            Decoded logits in text form. Usually the speech transcription.
        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
            Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
            offsets can be used to compute time stamps for each character.
    """

    text: Union[List[str], str]
    # stays `None` unless character offsets were requested
    char_offsets: Optional[Union[List[Dict[str, Union[float, str]]], List[List[Dict[str, Union[float, str]]]]]] = None
|
||||
|
||||
|
||||
class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
"""
|
||||
@@ -284,24 +321,69 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
||||
group_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = False,
|
||||
filter_word_delimiter_token: bool = True,
|
||||
output_char_offsets: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Converts a connectionist-temporal-classification (CTC) output tokens into a single string.
|
||||
"""
|
||||
# group same tokens into non-repeating tokens in CTC style decoding
|
||||
if group_tokens:
|
||||
tokens = [token_group[0] for token_group in groupby(tokens)]
|
||||
chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
|
||||
else:
|
||||
chars = tokens
|
||||
char_repetitions = len(tokens) * [1]
|
||||
|
||||
# filter self.pad_token which is used as CTC-blank token
|
||||
filtered_tokens = list(filter(lambda token: token != self.pad_token, tokens))
|
||||
processed_chars = list(filter(lambda char: char != self.pad_token, chars))
|
||||
|
||||
# also filter self.word_delimiter_token if it is not None
|
||||
if filter_word_delimiter_token and self.word_delimiter_token is not None:
|
||||
filtered_tokens = list(filter(lambda token: token != self.word_delimiter_token, filtered_tokens))
|
||||
processed_chars = list(filter(lambda token: token != self.word_delimiter_token, processed_chars))
|
||||
|
||||
string = " ".join(filtered_tokens).strip()
|
||||
# retrieve offsets
|
||||
char_offsets = None
|
||||
if output_char_offsets:
|
||||
word_delimiter_token_for_offsets = (
|
||||
self.word_delimiter_token if filter_word_delimiter_token is True else None
|
||||
)
|
||||
char_offsets = self._compute_offsets(
|
||||
char_repetitions, chars, self.pad_token, word_delimiter_token=word_delimiter_token_for_offsets
|
||||
)
|
||||
|
||||
return string
|
||||
if len(char_offsets) != len(processed_chars):
|
||||
raise ValueError(
|
||||
f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
|
||||
f" have to be of the same length, but are: `len(offsets)`: "
|
||||
f"{len(char_offsets)} and `len(processed_tokens)`: {len(processed_chars)}"
|
||||
)
|
||||
|
||||
# set tokens to correct processed token
|
||||
for i, char in enumerate(processed_chars):
|
||||
char_offsets[i]["char"] = char
|
||||
|
||||
string = " ".join(processed_chars).strip()
|
||||
|
||||
return {"text": string, "char_offsets": char_offsets}
|
||||
|
||||
@staticmethod
|
||||
def _compute_offsets(
|
||||
char_repetitions: List[int], chars: List[str], ctc_token: int, word_delimiter_token: Optional[int] = None
|
||||
) -> List[Dict[str, Union[str, int]]]:
|
||||
end_indices = np.asarray(char_repetitions).cumsum()
|
||||
start_indices = np.concatenate(([0], end_indices[:-1]))
|
||||
|
||||
offsets = [
|
||||
{"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices)
|
||||
]
|
||||
|
||||
# filter out CTC token
|
||||
offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets))
|
||||
|
||||
# filter out word delimiter token if necessary
|
||||
if word_delimiter_token is not None:
|
||||
offsets = list(filter(lambda offsets: offsets["char"] != word_delimiter_token, offsets))
|
||||
|
||||
return offsets
|
||||
|
||||
def _decode(
|
||||
self,
|
||||
@@ -311,6 +393,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
||||
group_tokens: bool = True,
|
||||
filter_word_delimiter_token: bool = True,
|
||||
spaces_between_special_tokens: bool = False,
|
||||
output_char_offsets: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
special _decode function is needed for Wav2Vec2PhonemeTokenizer because added tokens should be treated exactly
|
||||
@@ -325,19 +408,137 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
||||
continue
|
||||
result.append(token)
|
||||
|
||||
text = self.convert_tokens_to_string(
|
||||
string_output = self.convert_tokens_to_string(
|
||||
result,
|
||||
group_tokens=group_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
filter_word_delimiter_token=filter_word_delimiter_token,
|
||||
output_char_offsets=output_char_offsets,
|
||||
)
|
||||
|
||||
text = string_output["text"]
|
||||
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = self.clean_up_tokenization(text)
|
||||
return clean_text
|
||||
text = self.clean_up_tokenization(text)
|
||||
|
||||
if output_char_offsets:
|
||||
return Wav2Vec2PhonemeCTCTokenizerOutput(text=text, char_offsets=string_output["char_offsets"])
|
||||
else:
|
||||
return text
|
||||
|
||||
# overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets` here
|
||||
def decode(
    self,
    token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    **kwargs
) -> str:
    """
    Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
    tokens and clean up tokenization spaces.

    Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the example of [`~Wav2Vec2CTCTokenizer.decode`], which shows the analogous
            workflow for word offsets. This method works the same way with phonemes and `output_char_offsets`.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `str` or [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`]: The
        decoded sentence. Will be a
        [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`] when
        `output_char_offsets == True`.
    """
    # Convert inputs to python lists
    token_ids = to_py_obj(token_ids)

    return self._decode(
        token_ids=token_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        output_char_offsets=output_char_offsets,
        **kwargs,
    )
|
||||
|
||||
# overwritten from `tokenization_utils_base.py` because tokenizer can output
|
||||
# `ModelOutput` which should not be a list for batched output and because
|
||||
# we need docs for `output_char_offsets` here
|
||||
def batch_decode(
    self,
    sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    **kwargs
) -> List[str]:
    """
    Convert a list of lists of token ids into a list of strings by calling decode.

    Args:
        sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of offsets (shown there with `output_word_offsets`).
            [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.batch_decode`] works analogously with
            `output_char_offsets`, phonemes and batched output.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `List[str]` or [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`]:
        The decoded sentences. Will be a single
        [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`] whose fields
        are lists with one entry per sequence when `output_char_offsets == True`.
    """
    batch_decoded = [
        self.decode(
            seq,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            output_char_offsets=output_char_offsets,
            **kwargs,
        )
        for seq in sequences
    ]

    if not output_char_offsets:
        return batch_decoded

    if not batch_decoded:
        # An empty batch has no first element to read the output keys from; return an
        # empty output instead of failing with an IndexError on `batch_decoded[0]`.
        return Wav2Vec2PhonemeCTCTokenizerOutput(text=[], char_offsets=[])

    # transform list of dicts to dict of lists
    return Wav2Vec2PhonemeCTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]})
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" WavLM model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@@ -330,3 +332,7 @@ class WavLMConfig(PretrainedConfig):
|
||||
self.tdnn_kernel = list(tdnn_kernel)
|
||||
self.tdnn_dilation = list(tdnn_dilation)
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride`.

    NOTE(review): presumably the factor by which raw input samples are downsampled to logit frames - confirm.
    """
    # `math.prod` only exists on Python >= 3.8; `functools.reduce` yields the same product
    # (including 1 for an empty sequence) on older interpreters as well.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
Wav2Vec2CTCTokenizer,
|
||||
Wav2Vec2Tokenizer,
|
||||
)
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizerOutput
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -422,27 +422,16 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_tokenizer_decode_special(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
sample_ids_2 = [
|
||||
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
77,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
],
|
||||
[24, 22, 5, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.word_delimiter_token_id],
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
||||
@@ -454,27 +443,12 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer.add_tokens(["!", "?"])
|
||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[
|
||||
11,
|
||||
5,
|
||||
15,
|
||||
tokenizer.pad_token_id,
|
||||
15,
|
||||
8,
|
||||
98,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
34,
|
||||
],
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||
]
|
||||
# fmt: on
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
@@ -499,6 +473,187 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
@staticmethod
def get_from_offsets(offsets, key):
    """Return the value stored under `key` for every offset dict in `offsets`, in order."""
    return [entry[key] for entry in offsets]
|
||||
|
||||
def test_offsets(self):
    """Check the char and word offsets returned by `decode` on a hand-crafted id sequence."""
    tokenizer = self.get_tokenizer()

    # fmt: off
    # HEEEEE||LLL<pad>LO<unk> => HE LLO<unk>
    # 1H + 5E + 2| + 3L + 1<pad> + 1L + 1O + 1<unk>
    sample_ids = [11, 5, 5, 5, 5, 5, 4, 4, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98]
    # fmt: on

    outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True)
    # check Wav2Vec2CTCTokenizerOutput keys for char
    # (`assertEqual`, not `assertTrue`: the latter would treat `2` as the failure message and always pass)
    self.assertEqual(len(outputs_char.keys()), 2)
    self.assertIn("text", outputs_char)
    self.assertIn("char_offsets", outputs_char)
    self.assertIsInstance(outputs_char, Wav2Vec2CTCTokenizerOutput)

    outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True)
    # check Wav2Vec2CTCTokenizerOutput keys for word
    self.assertEqual(len(outputs_word.keys()), 2)
    self.assertIn("text", outputs_word)
    self.assertIn("word_offsets", outputs_word)
    self.assertIsInstance(outputs_word, Wav2Vec2CTCTokenizerOutput)

    outputs = tokenizer.decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
    # check Wav2Vec2CTCTokenizerOutput keys for both
    self.assertEqual(len(outputs.keys()), 3)
    self.assertIn("text", outputs)
    self.assertIn("char_offsets", outputs)
    self.assertIn("word_offsets", outputs)
    self.assertIsInstance(outputs, Wav2Vec2CTCTokenizerOutput)

    # check that order of chars is correct and identical for both outputs
    self.assertEqual("".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
    self.assertEqual(
        self.get_from_offsets(outputs["char_offsets"], "char"), ["H", "E", " ", "L", "L", "O", "<unk>"]
    )
    self.assertListEqual(
        self.get_from_offsets(outputs["char_offsets"], "char"),
        self.get_from_offsets(outputs_char["char_offsets"], "char"),
    )

    # check that order of words is correct and identical to both outputs
    self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
    self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["HE", "LLO<unk>"])
    self.assertListEqual(
        self.get_from_offsets(outputs["word_offsets"], "word"),
        self.get_from_offsets(outputs_word["word_offsets"], "word"),
    )

    # check that offsets are actually correct for char
    # 0 is H, 1 is E, 6 is | (" "), 8 is 1st L, 12 is 2nd L, 13 is O, 14 is <unk>
    self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 6, 8, 12, 13, 14])
    # 1 is H, 6 is E, 8 is | (" "), 11 is 1st L (note due to <pad>
    # different begin of 2nd L), 13 is 2nd L, 14 is O, 15 is <unk>
    self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 6, 8, 11, 13, 14, 15])

    # check that offsets are actually correct for word
    # first word starts with H at offset 0, second word starts with the 1st L at offset 8
    self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 8])
    # first word ends with the last E at offset 6, second word ends with <unk> at offset 15
    self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [6, 15])
|
||||
|
||||
def test_offsets_batch(self):
    """Check that `batch_decode` with offsets returns the same content as per-sequence `decode` calls."""
    tokenizer = self.get_tokenizer()

    def check_list_tuples_equal(outputs_batch, outputs_list):
        self.assertIsInstance(outputs_batch, Wav2Vec2CTCTokenizerOutput)
        self.assertIsInstance(outputs_list[0], Wav2Vec2CTCTokenizerOutput)

        # transform list to ModelOutput
        outputs_batch_2 = Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in outputs_list] for k in outputs_list[0]})

        self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])

        def recursive_check(list_or_dict_1, list_or_dict_2):
            # descend into nested lists first so a mismatch is reported at the innermost element
            if isinstance(list_or_dict_1, list):
                for l1, l2 in zip(list_or_dict_1, list_or_dict_2):
                    recursive_check(l1, l2)
            # comparing the full objects also catches a length mismatch that `zip` would truncate
            self.assertEqual(list_or_dict_1, list_or_dict_2)

        if "char_offsets" in outputs_batch:
            recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])

        if "word_offsets" in outputs_batch:
            recursive_check(outputs_batch["word_offsets"], outputs_batch_2["word_offsets"])

    # fmt: off
    sample_ids = [
        [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
        [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
    ]
    # fmt: on

    # We assume that `decode` works as expected. All we will check now is
    # the output type is correct and the output is identical to `decode`

    # char
    outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
    outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
    check_list_tuples_equal(outputs_char_batch, outputs_char)

    # word
    outputs_word_batch = tokenizer.batch_decode(sample_ids, output_word_offsets=True)
    outputs_word = [tokenizer.decode(ids, output_word_offsets=True) for ids in sample_ids]
    check_list_tuples_equal(outputs_word_batch, outputs_word)

    # both
    outputs_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
    outputs = [tokenizer.decode(ids, output_word_offsets=True, output_char_offsets=True) for ids in sample_ids]
    check_list_tuples_equal(outputs_batch, outputs)
|
||||
|
||||
def test_offsets_integration(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
# pred_ids correspond to the following code
|
||||
# ```
|
||||
# from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
|
||||
# from datasets import load_dataset
|
||||
# import datasets
|
||||
# import torch
|
||||
# model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
# feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
#
|
||||
# ds = load_dataset("common_voice", "en", split="train", streaming=True)
|
||||
# ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
# ds_iter = iter(ds)
|
||||
# sample = next(ds_iter)
|
||||
#
|
||||
# input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
|
||||
# logits = model(input_values).logits
|
||||
# pred_ids = torch.argmax(logits, axis=-1).cpu().tolist()
|
||||
# ```
|
||||
# fmt: off
|
||||
pred_ids = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 11, 0, 0, 0, 22, 0, 0, 4, 4, 4, 14, 0, 0, 0, 0, 0, 8, 8, 0, 5, 5, 0, 12, 0, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 10, 0, 0, 0, 15, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 7, 0, 9, 0, 0, 14, 0, 0, 0, 13, 0, 7, 0, 0, 4, 4, 0, 15, 8, 8, 0, 0, 8, 0, 26, 0, 0, 4, 4, 0, 0, 15, 0, 0, 0, 0, 0, 0, 10, 0, 26, 5, 5, 0, 4, 4, 0, 0, 12, 11, 0, 0, 5, 4, 4, 4, 0, 18, 0, 0, 0, 7, 9, 9, 0, 6, 0, 12, 12, 4, 4, 0, 6, 0, 0, 8, 0, 4, 4, 4, 0, 19, 0, 0, 8, 9, 9, 0, 0, 0, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 0, 17, 5, 5, 5, 0, 4, 4, 4, 0, 0, 29, 29, 0, 0, 0, 0, 8, 11, 0, 9, 9, 0, 0, 0, 4, 4, 0, 12, 12, 0, 0, 0, 9, 0, 0, 0, 0, 0, 8, 18, 0, 0, 0, 4, 4, 0, 0, 8, 9, 0, 4, 4, 0, 6, 11, 5, 0, 4, 4, 0, 13, 13, 0, 0, 0, 10, 0, 0, 25, 0, 0, 6, 0, 4, 4, 0, 0, 0, 0, 7, 0, 0, 23, 0, 0, 4, 4, 0, 0, 0, 6, 11, 0, 5, 4, 4, 18, 0, 0, 0, 0, 0, 0, 7, 15, 0, 0, 0, 15, 15, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
|
||||
# wav2vec2-base downsamples input audio by a factor of 320
|
||||
# sampling rate for wav2vec2-base is 16_000
|
||||
time_offset_wav2vec2_base = 320 / 16_000
|
||||
|
||||
expected_char_time_stamps_text = ['W', 'H', 'Y', ' ', 'D', 'O', 'E', 'S', ' ', 'M', 'I', 'L', 'I', 'S', 'A', 'N', 'D', 'R', 'A', ' ', 'L', 'O', 'O', 'K', ' ', 'L', 'I', 'K', 'E', ' ', 'S', 'H', 'E', ' ', 'W', 'A', 'N', 'T', 'S', ' ', 'T', 'O', ' ', 'C', 'O', 'N', 'S', 'U', 'M', 'E', ' ', 'J', 'O', 'H', 'N', ' ', 'S', 'N', 'O', 'W', ' ', 'O', 'N', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'T', ' ', 'A', 'P', ' ', 'T', 'H', 'E', ' ', 'W', 'A', 'L', 'L', ' ']
|
||||
expected_char_time_stamps_start = [1.42, 1.44, 1.52, 1.58, 1.64, 1.76, 1.82, 1.88, 1.92, 2.26, 2.32, 2.4, 2.46, 2.54, 2.66, 2.7, 2.76, 2.84, 2.88, 2.94, 3.0, 3.02, 3.1, 3.14, 3.2, 3.28, 3.42, 3.46, 3.48, 3.54, 3.62, 3.64, 3.7, 3.72, 3.8, 3.88, 3.9, 3.96, 4.0, 4.04, 4.1, 4.16, 4.2, 4.28, 4.34, 4.36, 4.48, 4.66, 4.74, 4.76, 4.84, 4.94, 5.06, 5.08, 5.12, 5.22, 5.28, 5.38, 5.5, 5.52, 5.6, 5.68, 5.7, 5.74, 5.8, 5.82, 5.84, 5.88, 5.94, 6.04, 6.1, 6.16, 6.2, 6.32, 6.38, 6.44, 6.54, 6.56, 6.6, 6.62, 6.66, 6.8, 6.82, 6.9, 6.96]
|
||||
expected_char_time_stamps_end = [1.44, 1.46, 1.54, 1.64, 1.66, 1.8, 1.86, 1.9, 2.06, 2.28, 2.34, 2.42, 2.48, 2.56, 2.68, 2.72, 2.78, 2.86, 2.9, 2.98, 3.02, 3.06, 3.12, 3.16, 3.24, 3.3, 3.44, 3.48, 3.52, 3.58, 3.64, 3.66, 3.72, 3.78, 3.82, 3.9, 3.94, 3.98, 4.04, 4.08, 4.12, 4.18, 4.26, 4.3, 4.36, 4.4, 4.52, 4.7, 4.76, 4.82, 4.9, 4.98, 5.08, 5.1, 5.16, 5.26, 5.32, 5.4, 5.52, 5.54, 5.64, 5.7, 5.72, 5.78, 5.82, 5.84, 5.86, 5.92, 5.98, 6.06, 6.12, 6.18, 6.24, 6.34, 6.4, 6.48, 6.56, 6.58, 6.62, 6.66, 6.68, 6.82, 6.84, 6.94, 7.02]
|
||||
|
||||
expected_word_time_stamps_text = ['WHY', 'DOES', 'MILISANDRA', 'LOOK', 'LIKE', 'SHE', 'WANTS', 'TO', 'CONSUME', 'JOHN', 'SNOW', 'ON', 'THE', 'RIVT', 'AP', 'THE', 'WALL']
|
||||
expected_word_time_stamps_start = [1.42, 1.64, 2.26, 3.0, 3.28, 3.62, 3.8, 4.1, 4.28, 4.94, 5.28, 5.68, 5.8, 5.94, 6.32, 6.54, 6.66]
|
||||
expected_word_time_stamps_end = [1.54, 1.9, 2.9, 3.16, 3.52, 3.72, 4.04, 4.18, 4.82, 5.16, 5.54, 5.72, 5.86, 6.18, 6.4, 6.62, 6.94]
|
||||
# fmt: on
|
||||
|
||||
output = tokenizer.batch_decode(pred_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
|
||||
char_offsets_text = self.get_from_offsets(output["char_offsets"][0], "char")
|
||||
char_offsets_start = self.get_from_offsets(output["char_offsets"][0], "start_offset")
|
||||
char_offsets_end = self.get_from_offsets(output["char_offsets"][0], "end_offset")
|
||||
|
||||
word_offsets_text = self.get_from_offsets(output["word_offsets"][0], "word")
|
||||
word_offsets_start = self.get_from_offsets(output["word_offsets"][0], "start_offset")
|
||||
word_offsets_end = self.get_from_offsets(output["word_offsets"][0], "end_offset")
|
||||
|
||||
# let's transform offsets to time stamps in seconds
|
||||
char_time_stamps_start = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_start]
|
||||
char_time_stamps_end = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_end]
|
||||
|
||||
word_time_stamps_start = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_start]
|
||||
word_time_stamps_end = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_end]
|
||||
|
||||
# NOTE: you can verify the above results by checking out the dataset viewer
|
||||
# on https://huggingface.co/datasets/common_voice/viewer/en/train and
|
||||
# downloading / playing the sample `common_voice_en_100038.mp3`. As
|
||||
# you can hear the time-stamps match more or less
|
||||
|
||||
self.assertListEqual(expected_char_time_stamps_text, char_offsets_text)
|
||||
self.assertListEqual(expected_char_time_stamps_start, char_time_stamps_start)
|
||||
self.assertListEqual(expected_char_time_stamps_end, char_time_stamps_end)
|
||||
|
||||
self.assertListEqual(expected_word_time_stamps_text, word_offsets_text)
|
||||
self.assertListEqual(expected_word_time_stamps_start, word_time_stamps_start)
|
||||
self.assertListEqual(expected_word_time_stamps_end, word_time_stamps_end)
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no testing
|
||||
pass
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Tuple
|
||||
|
||||
from transformers import Wav2Vec2PhonemeCTCTokenizer
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizerOutput
|
||||
from transformers.testing_utils import require_phonemizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -248,23 +249,94 @@ class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? $$$", "j ð s j ð s oːɹ $$$"])
|
||||
|
||||
# overwrite common test
|
||||
@staticmethod
def get_from_offsets(offsets, key):
    """Return the value stored under `key` for every offset dict in `offsets`, in order."""
    return [entry[key] for entry in offsets]
|
||||
|
||||
def test_offsets(self):
    """Check the char (phoneme) offsets returned by `decode` on a hand-crafted id sequence."""
    tokenizer = self.get_tokenizer(word_delimiter_token="|")
    tokenizer.add_tokens("|")

    # fmt: off
    # ksssɾɾ<pad>ɾɾ|<pad>ɾlll|ɭʲ -> k s ɾ ɾ | ɾ l | ɭʲ
    sample_ids = [11, 5, 5, 5, 15, 15, tokenizer.pad_token_id, 15, 15, tokenizer.word_delimiter_token_id, tokenizer.pad_token_id, 15, 8, 8, 8, tokenizer.word_delimiter_token_id, 98]
    # fmt: on

    outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False)
    # check Wav2Vec2PhonemeCTCTokenizerOutput keys for char
    # (`assertEqual`, not `assertTrue`: the latter would treat `2` as the failure message and always pass)
    self.assertEqual(len(outputs.keys()), 2)
    self.assertIn("text", outputs)
    self.assertIn("char_offsets", outputs)
    self.assertIsInstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput)

    # check that order of chars is correct and matches the decoded text
    self.assertEqual(" ".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
    self.assertListEqual(
        self.get_from_offsets(outputs["char_offsets"], "char"), ["k", "s", "ɾ", "ɾ", "|", "ɾ", "l", "|", "ɭʲ"]
    )

    # check that offsets are actually correct for char
    # 0-1 is 11, 1-4 is 5, 4-6 is first 15, 6-7 is <pad> (thus not shown), 7-9 is second 15, 9-10 is word_delimiter_token,
    # 10-11 is <pad> (thus not shown), 11-12 is third 15, 12-15 is 8, 15-16 is word_delimiter_token, 16-17 is 98
    self.assertListEqual(
        self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 4, 7, 9, 11, 12, 15, 16]
    )
    self.assertListEqual(
        self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 4, 6, 9, 10, 12, 15, 16, 17]
    )
|
||||
|
||||
def test_offsets_batch(self):
    """Check that `batch_decode` with char offsets returns the same content as per-sequence `decode` calls."""
    tokenizer = self.get_tokenizer(word_delimiter_token="|")

    def check_list_tuples_equal(outputs_batch, outputs_list):
        self.assertIsInstance(outputs_batch, Wav2Vec2PhonemeCTCTokenizerOutput)
        self.assertIsInstance(outputs_list[0], Wav2Vec2PhonemeCTCTokenizerOutput)

        # transform list to ModelOutput
        outputs_batch_2 = Wav2Vec2PhonemeCTCTokenizerOutput(
            {k: [d[k] for d in outputs_list] for k in outputs_list[0]}
        )

        self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])

        def recursive_check(list_or_dict_1, list_or_dict_2):
            # descend into nested lists first so a mismatch is reported at the innermost element
            if isinstance(list_or_dict_1, list):
                for l1, l2 in zip(list_or_dict_1, list_or_dict_2):
                    recursive_check(l1, l2)
            # comparing the full objects also catches a length mismatch that `zip` would truncate
            self.assertEqual(list_or_dict_1, list_or_dict_2)

        if "char_offsets" in outputs_batch:
            recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])

    # fmt: off
    sample_ids = [
        [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
        [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
    ]
    # fmt: on

    # We assume that `decode` works as expected. All we will check now is
    # the output type is correct and the output is identical to `decode`

    # char
    outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
    outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
    check_list_tuples_equal(outputs_char_batch, outputs_char)
|
||||
|
||||
@unittest.skip("Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes")
|
||||
def test_added_tokens_do_lower_case(self):
|
||||
# Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes
|
||||
pass
|
||||
|
||||
# overwrite common test
|
||||
@unittest.skip("Wav2Vec2PhonemeTokenizer always puts spaces between phonemes")
|
||||
def test_encode_decode_with_spaces(self):
|
||||
# Wav2Vec2PhonemeTokenizer always puts spaces between phonemes
|
||||
pass
|
||||
|
||||
# overwrite common test
|
||||
@unittest.skip("encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency")
|
||||
def test_internal_consistency(self):
|
||||
# encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency
|
||||
pass
|
||||
|
||||
@unittest.skip("Wav2Vec2PhonemeModel has no max model length => no testing")
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2PhonemeModel has no max model length => no testing
|
||||
pass
|
||||
|
||||
# overwrite common
|
||||
|
||||
Reference in New Issue
Block a user