Time stamps for CTC models (#15687)
* [Wav2Vec2 Time Stamps] * Add first version * add word time stamps * Fix * save intermediate space * improve * [Finish CTC Tokenizer] * remove @ * remove @ * push * continue with phonemes * up * finish PR * up * add example * rename * finish * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * correct split * finalize Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
32295b15a1
commit
c44d3675c2
@@ -45,6 +45,8 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
|||||||
[[autodoc]] Wav2Vec2CTCTokenizer
|
[[autodoc]] Wav2Vec2CTCTokenizer
|
||||||
- __call__
|
- __call__
|
||||||
- save_vocabulary
|
- save_vocabulary
|
||||||
|
- decode
|
||||||
|
- batch_decode
|
||||||
|
|
||||||
## Wav2Vec2FeatureExtractor
|
## Wav2Vec2FeatureExtractor
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Hubert model configuration"""
|
""" Hubert model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -248,3 +250,7 @@ class HubertConfig(PretrainedConfig):
|
|||||||
# ctc loss
|
# ctc loss
|
||||||
self.ctc_loss_reduction = ctc_loss_reduction
|
self.ctc_loss_reduction = ctc_loss_reduction
|
||||||
self.ctc_zero_infinity = ctc_zero_infinity
|
self.ctc_zero_infinity = ctc_zero_infinity
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Return the product of all values in `self.conv_stride`, i.e. the overall stride between inputs and logits.

    `functools.reduce` is used instead of `math.prod` because `math.prod` is only available on Python >= 3.8.
    An empty `conv_stride` yields `1`, exactly as `math.prod` would.
    """
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" SEW model configuration"""
|
""" SEW model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -243,3 +245,7 @@ class SEWConfig(PretrainedConfig):
|
|||||||
# sequence classification
|
# sequence classification
|
||||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
self.classifier_proj_size = classifier_proj_size
|
self.classifier_proj_size = classifier_proj_size
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Return the product of all values in `self.conv_stride`, i.e. the overall stride between inputs and logits.

    `functools.reduce` is used instead of `math.prod` because `math.prod` is only available on Python >= 3.8.
    An empty `conv_stride` yields `1`, exactly as `math.prod` would.
    """
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" SEW-D model configuration"""
|
""" SEW-D model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -279,3 +281,7 @@ class SEWDConfig(PretrainedConfig):
|
|||||||
# sequence classification
|
# sequence classification
|
||||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
self.classifier_proj_size = classifier_proj_size
|
self.classifier_proj_size = classifier_proj_size
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Return the product of all values in `self.conv_stride`, i.e. the overall stride between inputs and logits.

    `functools.reduce` is used instead of `math.prod` because `math.prod` is only available on Python >= 3.8.
    An empty `conv_stride` yields `1`, exactly as `math.prod` would.
    """
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" UniSpeech model configuration"""
|
""" UniSpeech model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -289,3 +291,7 @@ class UniSpeechConfig(PretrainedConfig):
|
|||||||
|
|
||||||
# pretraining loss
|
# pretraining loss
|
||||||
self.replace_prob = replace_prob
|
self.replace_prob = replace_prob
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Return the product of all values in `self.conv_stride`, i.e. the overall stride between inputs and logits.

    `functools.reduce` is used instead of `math.prod` because `math.prod` is only available on Python >= 3.8.
    An empty `conv_stride` yields `1`, exactly as `math.prod` would.
    """
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" UniSpeechSat model configuration"""
|
""" UniSpeechSat model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -306,3 +308,7 @@ class UniSpeechSatConfig(PretrainedConfig):
|
|||||||
self.tdnn_kernel = list(tdnn_kernel)
|
self.tdnn_kernel = list(tdnn_kernel)
|
||||||
self.tdnn_dilation = list(tdnn_dilation)
|
self.tdnn_dilation = list(tdnn_dilation)
|
||||||
self.xvector_output_dim = xvector_output_dim
|
self.xvector_output_dim = xvector_output_dim
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Return the product of all values in `self.conv_stride`, i.e. the overall stride between inputs and logits.

    `functools.reduce` is used instead of `math.prod` because `math.prod` is only available on Python >= 3.8.
    An empty `conv_stride` yields `1`, exactly as `math.prod` would.
    """
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Wav2Vec2 model configuration"""
|
""" Wav2Vec2 model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -329,3 +331,7 @@ class Wav2Vec2Config(PretrainedConfig):
|
|||||||
self.tdnn_kernel = list(tdnn_kernel)
|
self.tdnn_kernel = list(tdnn_kernel)
|
||||||
self.tdnn_dilation = list(tdnn_dilation)
|
self.tdnn_dilation = list(tdnn_dilation)
|
||||||
self.xvector_output_dim = xvector_output_dim
|
self.xvector_output_dim = xvector_output_dim
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Return the product of all values in `self.conv_stride`, i.e. the overall stride between inputs and logits.

    `functools.reduce` is used instead of `math.prod` because `math.prod` is only available on Python >= 3.8.
    An empty `conv_stride` yields `1`, exactly as `math.prod` would.
    """
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -18,12 +18,22 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
from ...file_utils import (
|
||||||
|
ModelOutput,
|
||||||
|
PaddingStrategy,
|
||||||
|
TensorType,
|
||||||
|
add_end_docstrings,
|
||||||
|
is_flax_available,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
to_py_obj,
|
||||||
|
)
|
||||||
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -32,6 +42,15 @@ from ...utils import logging
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
"vocab_file": "vocab.json",
|
"vocab_file": "vocab.json",
|
||||||
"tokenizer_config_file": "tokenizer_config.json",
|
"tokenizer_config_file": "tokenizer_config.json",
|
||||||
@@ -79,6 +98,28 @@ WAV2VEC2_KWARGS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Wav2Vec2CTCTokenizerOutput(ModelOutput):
    """
    Output type of [`Wav2Vec2CTCTokenizer`], with transcription.

    Args:
        text (list of `str` or `str`):
            Decoded logits in text form. Usually the speech transcription.
        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
            Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
            offsets can be used to compute time stamps for each character. One list of dicts per sample when the
            output was produced for a batch.
        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
            Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
            can be used to compute time stamps for each word. One list of dicts per sample when the output was
            produced for a batch.
    """

    text: Union[List[str], str]
    # Both offset fields default to `None`, so they are annotated as `Optional`.
    char_offsets: Optional[Union[List[Dict[str, Union[float, str]]], List[List[Dict[str, Union[float, str]]]]]] = None
    word_offsets: Optional[Union[List[Dict[str, Union[float, str]]], List[List[Dict[str, Union[float, str]]]]]] = None
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -121,6 +162,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
pad_token="<pad>",
|
pad_token="<pad>",
|
||||||
word_delimiter_token="|",
|
word_delimiter_token="|",
|
||||||
|
replace_word_delimiter_char=" ",
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -131,12 +173,14 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
do_lower_case=do_lower_case,
|
do_lower_case=do_lower_case,
|
||||||
word_delimiter_token=word_delimiter_token,
|
word_delimiter_token=word_delimiter_token,
|
||||||
|
replace_word_delimiter_char=replace_word_delimiter_char,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._word_delimiter_token = word_delimiter_token
|
self._word_delimiter_token = word_delimiter_token
|
||||||
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
self.replace_word_delimiter_char = replace_word_delimiter_char
|
||||||
|
|
||||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
self.encoder = json.load(vocab_handle)
|
self.encoder = json.load(vocab_handle)
|
||||||
@@ -204,31 +248,106 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def convert_tokens_to_string(
    self,
    tokens: List[str],
    group_tokens: bool = True,
    spaces_between_special_tokens: bool = False,
    output_char_offsets: bool = False,
    output_word_offsets: bool = False,
) -> Dict[str, Union[str, None, List[Dict[str, Union[str, int]]]]]:
    """
    Converts a connectionist-temporal-classification (CTC) output tokens into a single string.

    Args:
        tokens (`List[str]`):
            Tokens to convert, one per position. May be empty.
        group_tokens (`bool`, *optional*, defaults to `True`):
            Whether to collapse runs of identical consecutive tokens into one (CTC style decoding).
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether to join the resulting characters with a space instead of the empty string.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether to compute the offsets of every decoded character.
        output_word_offsets (`bool`, *optional*, defaults to `False`):
            Whether to compute the offsets of every decoded word. Character offsets are computed (and returned) as
            well in that case since word offsets are derived from them.

    Returns:
        `dict`: A dictionary with the keys `"text"` (the decoded string), `"char_offsets"` and `"word_offsets"`. The
        two offset entries are `None` unless the corresponding offsets were computed.

    Raises:
        `ValueError`: If the number of computed character offsets differs from the number of decoded characters.
    """
    if len(tokens) == 0:
        # `zip(*...)` over an empty generator cannot be unpacked into two names, so handle empty input explicitly
        chars, char_repetitions = [], []
    elif group_tokens:
        # group same tokens into non-repeating tokens in CTC style decoding
        chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
    else:
        chars = tokens
        char_repetitions = len(tokens) * [1]

    # filter self.pad_token which is used as CTC-blank token
    processed_chars = [char for char in chars if char != self.pad_token]

    # replace delimiter token
    processed_chars = [
        self.replace_word_delimiter_char if char == self.word_delimiter_token else char for char in processed_chars
    ]

    # retrieve offsets
    char_offsets = word_offsets = None
    if output_char_offsets or output_word_offsets:
        char_offsets = self._compute_offsets(char_repetitions, chars, self.pad_token)

        if len(char_offsets) != len(processed_chars):
            raise ValueError(
                f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
                " have to be of the same length, but are: "
                f"`len(offsets)`: {len(char_offsets)} and `len(processed_tokens)`:"
                f" {len(processed_chars)}"
            )

        # set tokens to correct processed token
        for offset, char in zip(char_offsets, processed_chars):
            offset["char"] = char

        # retrieve word offsets from character offsets
        if output_word_offsets:
            word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)

    # join to string
    join_char = " " if spaces_between_special_tokens else ""
    string = join_char.join(processed_chars).strip()

    if self.do_lower_case:
        string = string.lower()

    return {"text": string, "char_offsets": char_offsets, "word_offsets": word_offsets}
|
||||||
|
|
||||||
|
@staticmethod
def _compute_offsets(
    char_repetitions: List[int], chars: List[str], ctc_token: str
) -> List[Dict[str, Union[str, int]]]:
    """
    Compute the start and end offset of every character that is not the CTC blank token.

    Args:
        char_repetitions (`List[int]`):
            Number of consecutive positions covered by the character at the same index of `chars`.
        chars (`List[str]`):
            Characters, one entry per group of repeated tokens.
        ctc_token (`str`):
            Token that marks a CTC blank. Its entries are left out of the result but still advance the position.

    Returns:
        `List[Dict[str, Union[str, int]]]`: One dict per kept character with the keys `"char"`, `"start_offset"`
        (inclusive) and `"end_offset"` (exclusive), both expressed as plain Python `int` positions.
    """
    offsets = []
    start_offset = 0
    for char, repetitions in zip(chars, char_repetitions):
        # plain `int` keeps the offsets JSON-serializable and in line with the return annotation
        end_offset = start_offset + int(repetitions)
        if char != ctc_token:
            offsets.append({"char": char, "start_offset": start_offset, "end_offset": end_offset})
        start_offset = end_offset

    return offsets
|
||||||
|
|
||||||
|
@staticmethod
def _get_word_offsets(
    offsets: List[Dict[str, Union[str, float]]], word_delimiter_char: str = " "
) -> List[Dict[str, Union[str, float]]]:
    """
    Merge character offsets into word offsets.

    Args:
        offsets (`List[Dict[str, Union[str, float]]]`):
            Character offsets, each a dict with the keys `"char"`, `"start_offset"` and `"end_offset"`.
        word_delimiter_char (`str`, *optional*, defaults to `" "`):
            Character that separates words. It never becomes part of a word.

    Returns:
        `List[Dict[str, Union[str, float]]]`: One dict per word with the keys `"word"`, `"start_offset"` (start of
        the first character of the word) and `"end_offset"` (end of the last character of the word).
    """
    word_offsets = []
    # word that is currently being assembled, `None` while between words
    current_word = None

    for offset in offsets:
        char = offset["char"]

        if char == word_delimiter_char:
            # a delimiter closes the open word, if any; leading or repeated delimiters are simply skipped so
            # that they can neither re-append a word nor move its end offset
            if current_word is not None:
                word_offsets.append(current_word)
                current_word = None
            continue

        if current_word is None:
            current_word = {"word": "", "start_offset": offset["start_offset"]}

        current_word["word"] += char
        current_word["end_offset"] = offset["end_offset"]

    # the last word is not followed by a delimiter
    if current_word is not None:
        word_offsets.append(current_word)

    return word_offsets
|
||||||
|
|
||||||
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
||||||
if is_split_into_words:
|
if is_split_into_words:
|
||||||
@@ -242,6 +361,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
clean_up_tokenization_spaces: bool = True,
|
clean_up_tokenization_spaces: bool = True,
|
||||||
group_tokens: bool = True,
|
group_tokens: bool = True,
|
||||||
spaces_between_special_tokens: bool = False,
|
spaces_between_special_tokens: bool = False,
|
||||||
|
output_word_offsets: Optional[bool] = False,
|
||||||
|
output_char_offsets: Optional[bool] = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
||||||
@@ -256,16 +377,210 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
continue
|
continue
|
||||||
result.append(token)
|
result.append(token)
|
||||||
|
|
||||||
text = self.convert_tokens_to_string(
|
string_output = self.convert_tokens_to_string(
|
||||||
result, group_tokens=group_tokens, spaces_between_special_tokens=spaces_between_special_tokens
|
result,
|
||||||
|
group_tokens=group_tokens,
|
||||||
|
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||||
|
output_word_offsets=output_word_offsets,
|
||||||
|
output_char_offsets=output_char_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
text = string_output["text"]
|
||||||
|
|
||||||
if clean_up_tokenization_spaces:
|
if clean_up_tokenization_spaces:
|
||||||
clean_text = self.clean_up_tokenization(text)
|
text = self.clean_up_tokenization(text)
|
||||||
return clean_text
|
|
||||||
|
if output_word_offsets or output_char_offsets:
|
||||||
|
return Wav2Vec2CTCTokenizerOutput(
|
||||||
|
text=text,
|
||||||
|
char_offsets=string_output["char_offsets"],
|
||||||
|
word_offsets=string_output["word_offsets"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
# overwritten from `tokenization_utils_base.py` because tokenizer can output
|
||||||
|
# `ModelOutput` which should not be a list for batched output and
|
||||||
|
# because we need docs for `output_char_offsets` here
|
||||||
|
def batch_decode(
    self,
    sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    output_word_offsets: bool = False,
    **kwargs
) -> Union[List[str], "Wav2Vec2CTCTokenizerOutput"]:
    """
    Convert a list of lists of token ids into a list of strings by calling decode.

    Args:
        sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of `output_word_offsets`.
            [`~models.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.

            </Tip>

        output_word_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
            and model downsampling rate to compute the time-stamps of transcribed words.

            <Tip>

            Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of `output_word_offsets`.
            [`~models.wav2vec2.tokenization_wav2vec2.batch_decode`] works the same way with batched output.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `List[str]` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded
        sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
        `output_char_offsets == True` or `output_word_offsets == True`.
    """
    batch_decoded = [
        self.decode(
            seq,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            output_char_offsets=output_char_offsets,
            output_word_offsets=output_word_offsets,
            **kwargs,
        )
        for seq in sequences
    ]

    if output_char_offsets or output_word_offsets:
        if not batch_decoded:
            # empty batch: there is no first element to take the keys from, so build the empty output directly.
            # Character offsets are produced whenever any offsets are requested, word offsets only on request.
            return Wav2Vec2CTCTokenizerOutput(
                text=[],
                char_offsets=[],
                word_offsets=[] if output_word_offsets else None,
            )

        # transform list of dicts to dict of lists
        return Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]})

    return batch_decoded
|
||||||
|
|
||||||
|
# overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets`
|
||||||
|
# and `output_word_offsets` here
|
||||||
|
def decode(
    self,
    token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    output_word_offsets: bool = False,
    **kwargs
) -> Union[str, "Wav2Vec2CTCTokenizerOutput"]:
    """
    Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
    tokens and clean up tokenization spaces.

    Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of `output_word_offsets`.

            </Tip>

        output_word_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
            and model downsampling rate to compute the time-stamps of transcribed words.

            <Tip>

            Please take a look at the example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of `output_word_offsets`.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `str` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The decoded sentence. Will
        be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
        `output_char_offsets == True` or `output_word_offsets == True`.

    Example:

    ```python
    >>> # Let's see how to retrieve time steps for a model
    >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
    >>> from datasets import load_dataset
    >>> import datasets
    >>> import torch

    >>> # import model, feature extractor, tokenizer
    >>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    >>> # load first sample of English common_voice
    >>> dataset = load_dataset("common_voice", "en", split="train", streaming=True)
    >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
    >>> dataset_iter = iter(dataset)
    >>> sample = next(dataset_iter)

    >>> # forward sample through model to get greedily predicted transcription ids
    >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
    >>> logits = model(input_values).logits[0]
    >>> pred_ids = torch.argmax(logits, axis=-1)

    >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
    >>> outputs = tokenizer.decode(pred_ids, output_word_offsets=True)
    >>> # compute `time_offset` in seconds as the downsampling ratio divided by the sampling rate
    >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate

    >>> word_offsets = [
    ...     {
    ...         "word": d["word"],
    ...         "start_time": round(d["start_offset"] * time_offset, 2),
    ...         "end_time": round(d["end_offset"] * time_offset, 2),
    ...     }
    ...     for d in outputs.word_offsets
    ... ]
    >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
    >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
    >>> word_offsets
    >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
    >>> # 'start_time': 1.64, 'end_time': 1.90}, {'word': 'MILISANDRA',
    >>> # 'start_time': 2.26, 'end_time': 2.9}, {'word': 'LOOK', 'start_time': 3.0, 'end_time': 3.16}, ...
    ```"""
    # Convert inputs to python lists
    token_ids = to_py_obj(token_ids)

    return self._decode(
        token_ids=token_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        output_char_offsets=output_char_offsets,
        output_word_offsets=output_word_offsets,
        **kwargs,
    )
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||||
@@ -294,7 +609,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
Returns:
|
Returns:
|
||||||
`int`: The number of tokens actually added to the vocabulary.
|
`int`: The number of tokens actually added to the vocabulary.
|
||||||
|
|
||||||
Examples:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
||||||
@@ -551,6 +866,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
if self.do_lower_case:
|
if self.do_lower_case:
|
||||||
string = string.lower()
|
string = string.lower()
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
def _decode(
|
def _decode(
|
||||||
|
|||||||
@@ -17,10 +17,20 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ...file_utils import requires_backends
|
import numpy as np
|
||||||
|
|
||||||
|
from ...file_utils import (
|
||||||
|
ModelOutput,
|
||||||
|
is_flax_available,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
requires_backends,
|
||||||
|
to_py_obj,
|
||||||
|
)
|
||||||
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
||||||
from ...tokenization_utils_base import AddedToken
|
from ...tokenization_utils_base import AddedToken
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -29,6 +39,15 @@ from ...utils import logging
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
"vocab_file": "vocab.json",
|
"vocab_file": "vocab.json",
|
||||||
"tokenizer_config_file": "tokenizer_config.json",
|
"tokenizer_config_file": "tokenizer_config.json",
|
||||||
@@ -47,6 +66,24 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
    """
    Output type of [`Wav2Vec2PhonemeCTCTokenizer`], with transcription.

    Args:
        text (list of `str` or `str`):
            Decoded logits in text form. Usually the speech transcription.
        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
            Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
            offsets can be used to compute time stamps for each character. One list of dicts per sample when the
            output was produced for a batch.
    """

    text: Union[List[str], str]
    # Defaults to `None`, so the field is annotated as `Optional`.
    char_offsets: Optional[Union[List[Dict[str, Union[float, str]]], List[List[Dict[str, Union[float, str]]]]]] = None
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -284,24 +321,69 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
|||||||
group_tokens: bool = True,
|
group_tokens: bool = True,
|
||||||
spaces_between_special_tokens: bool = False,
|
spaces_between_special_tokens: bool = False,
|
||||||
filter_word_delimiter_token: bool = True,
|
filter_word_delimiter_token: bool = True,
|
||||||
|
output_char_offsets: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Converts a connectionist-temporal-classification (CTC) output tokens into a single string.
|
Converts a connectionist-temporal-classification (CTC) output tokens into a single string.
|
||||||
"""
|
"""
|
||||||
# group same tokens into non-repeating tokens in CTC style decoding
|
# group same tokens into non-repeating tokens in CTC style decoding
|
||||||
if group_tokens:
|
if group_tokens:
|
||||||
tokens = [token_group[0] for token_group in groupby(tokens)]
|
chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
|
||||||
|
else:
|
||||||
|
chars = tokens
|
||||||
|
char_repetitions = len(tokens) * [1]
|
||||||
|
|
||||||
# filter self.pad_token which is used as CTC-blank token
|
# filter self.pad_token which is used as CTC-blank token
|
||||||
filtered_tokens = list(filter(lambda token: token != self.pad_token, tokens))
|
processed_chars = list(filter(lambda char: char != self.pad_token, chars))
|
||||||
|
|
||||||
# also filter self.word_delimiter_token if requested and it is not None
|
# also filter self.word_delimiter_token if requested and it is not None
|
||||||
if filter_word_delimiter_token and self.word_delimiter_token is not None:
|
if filter_word_delimiter_token and self.word_delimiter_token is not None:
|
||||||
filtered_tokens = list(filter(lambda token: token != self.word_delimiter_token, filtered_tokens))
|
processed_chars = list(filter(lambda token: token != self.word_delimiter_token, processed_chars))
|
||||||
|
|
||||||
string = " ".join(filtered_tokens).strip()
|
# retrieve offsets
|
||||||
|
char_offsets = None
|
||||||
|
if output_char_offsets:
|
||||||
|
word_delimiter_token_for_offsets = (
|
||||||
|
self.word_delimiter_token if filter_word_delimiter_token is True else None
|
||||||
|
)
|
||||||
|
char_offsets = self._compute_offsets(
|
||||||
|
char_repetitions, chars, self.pad_token, word_delimiter_token=word_delimiter_token_for_offsets
|
||||||
|
)
|
||||||
|
|
||||||
return string
|
if len(char_offsets) != len(processed_chars):
|
||||||
|
raise ValueError(
|
||||||
|
f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
|
||||||
|
f" have to be of the same length, but are: `len(offsets)`: "
|
||||||
|
f"{len(char_offsets)} and `len(processed_tokens)`: {len(processed_chars)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# set tokens to correct processed token
|
||||||
|
for i, char in enumerate(processed_chars):
|
||||||
|
char_offsets[i]["char"] = char
|
||||||
|
|
||||||
|
string = " ".join(processed_chars).strip()
|
||||||
|
|
||||||
|
return {"text": string, "char_offsets": char_offsets}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_offsets(
|
||||||
|
char_repetitions: List[int], chars: List[str], ctc_token: int, word_delimiter_token: Optional[int] = None
|
||||||
|
) -> List[Dict[str, Union[str, int]]]:
|
||||||
|
end_indices = np.asarray(char_repetitions).cumsum()
|
||||||
|
start_indices = np.concatenate(([0], end_indices[:-1]))
|
||||||
|
|
||||||
|
offsets = [
|
||||||
|
{"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices)
|
||||||
|
]
|
||||||
|
|
||||||
|
# filter out CTC token
|
||||||
|
offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets))
|
||||||
|
|
||||||
|
# filter out word delimiter token if necessary
|
||||||
|
if word_delimiter_token is not None:
|
||||||
|
offsets = list(filter(lambda offsets: offsets["char"] != word_delimiter_token, offsets))
|
||||||
|
|
||||||
|
return offsets
|
||||||
|
|
||||||
def _decode(
|
def _decode(
|
||||||
self,
|
self,
|
||||||
@@ -311,6 +393,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
|||||||
group_tokens: bool = True,
|
group_tokens: bool = True,
|
||||||
filter_word_delimiter_token: bool = True,
|
filter_word_delimiter_token: bool = True,
|
||||||
spaces_between_special_tokens: bool = False,
|
spaces_between_special_tokens: bool = False,
|
||||||
|
output_char_offsets: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
special _decode function is needed for Wav2Vec2PhonemeTokenizer because added tokens should be treated exactly
|
special _decode function is needed for Wav2Vec2PhonemeTokenizer because added tokens should be treated exactly
|
||||||
@@ -325,19 +408,137 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
|
|||||||
continue
|
continue
|
||||||
result.append(token)
|
result.append(token)
|
||||||
|
|
||||||
text = self.convert_tokens_to_string(
|
string_output = self.convert_tokens_to_string(
|
||||||
result,
|
result,
|
||||||
group_tokens=group_tokens,
|
group_tokens=group_tokens,
|
||||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||||
filter_word_delimiter_token=filter_word_delimiter_token,
|
filter_word_delimiter_token=filter_word_delimiter_token,
|
||||||
|
output_char_offsets=output_char_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
text = string_output["text"]
|
||||||
|
|
||||||
if clean_up_tokenization_spaces:
|
if clean_up_tokenization_spaces:
|
||||||
clean_text = self.clean_up_tokenization(text)
|
text = self.clean_up_tokenization(text)
|
||||||
return clean_text
|
|
||||||
|
if output_char_offsets:
|
||||||
|
return Wav2Vec2PhonemeCTCTokenizerOutput(text=text, char_offsets=string_output["char_offsets"])
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
# overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets` here
|
||||||
|
def decode(
    self,
    token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    **kwargs
) -> Union[str, "Wav2Vec2PhonemeCTCTokenizerOutput"]:
    """
    Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
    tokens and clean up tokenization spaces.

    Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of `output_word_offsets`.
            [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.decode`] works the same way with
            phonemes and `output_char_offsets`.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `str` or [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`]: The
        decoded sentence. Will be a
        [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`] when
        `output_char_offsets == True`.
    """
    # Convert inputs to python lists
    token_ids = to_py_obj(token_ids)

    return self._decode(
        token_ids=token_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        output_char_offsets=output_char_offsets,
        **kwargs,
    )
|
||||||
|
|
||||||
|
# overwritten from `tokenization_utils_base.py` because tokenizer can output
|
||||||
|
# `ModelOutput` which should not be a list for batched output and because
|
||||||
|
# we need docs for `output_char_offsets` here
|
||||||
|
def batch_decode(
    self,
    sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = True,
    output_char_offsets: bool = False,
    **kwargs
) -> Union[List[str], "Wav2Vec2PhonemeCTCTokenizerOutput"]:
    """
    Convert a list of lists of token ids into a list of strings by calling decode.

    Args:
        sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to clean up the tokenization spaces.
        output_char_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output character offsets. Character offsets can be used in combination with the
            sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.

            <Tip>

            Please take a look at the Example of [`~models.wav2vec2.tokenization_wav2vec2.decode`] to better
            understand how to make use of `output_word_offsets`.
            [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.batch_decode`] works analogously with
            phonemes, `output_char_offsets` and batched output.

            </Tip>

        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `List[str]` or [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`]:
        The decoded sentences. Will be a
        [`~models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme.Wav2Vec2PhonemeCTCTokenizerOutput`] (whose
        fields are lists with one entry per sequence) when `output_char_offsets == True`.
    """
    batch_decoded = [
        self.decode(
            seq,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            output_char_offsets=output_char_offsets,
            **kwargs,
        )
        for seq in sequences
    ]

    if not output_char_offsets:
        return batch_decoded

    if not batch_decoded:
        # An empty batch has no first element to read the field names from; return an output object with
        # empty fields instead of failing with an `IndexError`.
        return Wav2Vec2PhonemeCTCTokenizerOutput(text=[], char_offsets=[])

    # transform list of dicts to dict of lists
    return Wav2Vec2PhonemeCTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]})
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" WavLM model configuration"""
|
""" WavLM model configuration"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -330,3 +332,7 @@ class WavLMConfig(PretrainedConfig):
|
|||||||
self.tdnn_kernel = list(tdnn_kernel)
|
self.tdnn_kernel = list(tdnn_kernel)
|
||||||
self.tdnn_dilation = list(tdnn_dilation)
|
self.tdnn_dilation = list(tdnn_dilation)
|
||||||
self.xvector_output_dim = xvector_output_dim
|
self.xvector_output_dim = xvector_output_dim
|
||||||
|
|
||||||
|
@property
def inputs_to_logits_ratio(self):
    """
    Product of all entries of `self.conv_stride` (`1` if it is empty).

    NOTE(review): presumably the number of raw input samples that correspond to one logit frame - confirm
    against the feature encoder definition.
    """
    # `math.prod` only exists on Python >= 3.8; `functools.reduce` yields the same product on every supported
    # interpreter. Imported locally so that this property does not rely on additional module-level imports.
    import functools
    import operator

    return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from transformers import (
|
|||||||
Wav2Vec2CTCTokenizer,
|
Wav2Vec2CTCTokenizer,
|
||||||
Wav2Vec2Tokenizer,
|
Wav2Vec2Tokenizer,
|
||||||
)
|
)
|
||||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizerOutput
|
||||||
from transformers.testing_utils import require_torch, slow
|
from transformers.testing_utils import require_torch, slow
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
@@ -422,27 +422,16 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_tokenizer_decode_special(self):
|
def test_tokenizer_decode_special(self):
|
||||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
sample_ids = [
|
sample_ids = [
|
||||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||||
]
|
]
|
||||||
sample_ids_2 = [
|
sample_ids_2 = [
|
||||||
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||||
[
|
[24, 22, 5, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.word_delimiter_token_id],
|
||||||
24,
|
|
||||||
22,
|
|
||||||
5,
|
|
||||||
tokenizer.pad_token_id,
|
|
||||||
tokenizer.pad_token_id,
|
|
||||||
tokenizer.pad_token_id,
|
|
||||||
tokenizer.word_delimiter_token_id,
|
|
||||||
24,
|
|
||||||
22,
|
|
||||||
5,
|
|
||||||
77,
|
|
||||||
tokenizer.word_delimiter_token_id,
|
|
||||||
],
|
|
||||||
]
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||||
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
||||||
@@ -454,27 +443,12 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer.add_tokens(["!", "?"])
|
tokenizer.add_tokens(["!", "?"])
|
||||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
sample_ids = [
|
sample_ids = [
|
||||||
[
|
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||||
11,
|
|
||||||
5,
|
|
||||||
15,
|
|
||||||
tokenizer.pad_token_id,
|
|
||||||
15,
|
|
||||||
8,
|
|
||||||
98,
|
|
||||||
32,
|
|
||||||
32,
|
|
||||||
33,
|
|
||||||
tokenizer.word_delimiter_token_id,
|
|
||||||
32,
|
|
||||||
32,
|
|
||||||
33,
|
|
||||||
34,
|
|
||||||
34,
|
|
||||||
],
|
|
||||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||||
]
|
]
|
||||||
|
# fmt: on
|
||||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||||
|
|
||||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||||
@@ -499,6 +473,187 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||||
self.assertEqual(sent, expected_sent)
|
self.assertEqual(sent, expected_sent)
|
||||||
|
|
||||||
|
@staticmethod
def get_from_offsets(offsets, key):
    """Return the value stored under `key` in every offset dict of `offsets`, in order."""
    return [entry[key] for entry in offsets]
|
||||||
|
|
||||||
|
def test_offsets(self):
    """Check the keys, ordering and values of the char and word offsets returned by `decode`."""
    tokenizer = self.get_tokenizer()

    # fmt: off
    # HEEEEE||LLL<pad>LO<unk> => HE LLO<unk>
    # 1H + 5E + 2| + 3L + 1<pad> + 1L + 1O + 1<unk>
    sample_ids = [11, 5, 5, 5, 5, 5, 4, 4, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98]
    # fmt: on

    outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True)
    # check Wav2Vec2CTCTokenizerOutput keys for char
    # (`assertEqual`, not `assertTrue`: the latter would treat the expected count as the failure message)
    self.assertEqual(len(outputs_char.keys()), 2)
    self.assertIn("text", outputs_char)
    self.assertIn("char_offsets", outputs_char)
    self.assertIsInstance(outputs_char, Wav2Vec2CTCTokenizerOutput)

    outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True)
    # check Wav2Vec2CTCTokenizerOutput keys for word
    self.assertEqual(len(outputs_word.keys()), 2)
    self.assertIn("text", outputs_word)
    self.assertIn("word_offsets", outputs_word)
    self.assertIsInstance(outputs_word, Wav2Vec2CTCTokenizerOutput)

    outputs = tokenizer.decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
    # check Wav2Vec2CTCTokenizerOutput keys for both
    self.assertEqual(len(outputs.keys()), 3)
    self.assertIn("text", outputs)
    self.assertIn("char_offsets", outputs)
    self.assertIn("word_offsets", outputs)
    self.assertIsInstance(outputs, Wav2Vec2CTCTokenizerOutput)

    # check that order of chars is correct and identical for both outputs
    self.assertEqual("".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
    self.assertEqual(
        self.get_from_offsets(outputs["char_offsets"], "char"), ["H", "E", " ", "L", "L", "O", "<unk>"]
    )
    self.assertListEqual(
        self.get_from_offsets(outputs["char_offsets"], "char"),
        self.get_from_offsets(outputs_char["char_offsets"], "char"),
    )

    # check that order of words is correct and identical to both outputs
    self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
    self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["HE", "LLO<unk>"])
    self.assertListEqual(
        self.get_from_offsets(outputs["word_offsets"], "word"),
        self.get_from_offsets(outputs_word["word_offsets"], "word"),
    )

    # check that offsets are actually correct for char
    # 0 is H, 1 is E, 6 is | (" "), 8 is 1st L, 12 is 2nd L, 13 is O, 14 is <unk>
    self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 6, 8, 12, 13, 14])
    # 1 is H, 6 is E, 8 is | (" "), 11 is 1st L (note due to <pad>
    # different begin of 2nd L), 13 is 2nd L, 14 is O, 15 is <unk>
    self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 6, 8, 11, 13, 14, 15])

    # check that offsets are actually correct for word
    # the first word starts with H at offset 0, the second word starts with the 1st L at offset 8
    self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 8])
    # the first word ends with the last E at offset 6, the second word ends with <unk> at offset 15
    self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [6, 15])
|
||||||
|
|
||||||
|
def test_offsets_batch(self):
    """Check that `batch_decode` with offsets returns the same content as calling `decode` per sequence."""
    tokenizer = self.get_tokenizer()

    def recursive_check(list_or_dict_1, list_or_dict_2):
        # descend into nested lists first, then compare the current level as a whole
        if isinstance(list_or_dict_1, list):
            for item_1, item_2 in zip(list_or_dict_1, list_or_dict_2):
                recursive_check(item_1, item_2)
        self.assertEqual(list_or_dict_1, list_or_dict_2)

    def check_list_tuples_equal(outputs_batch, outputs_list):
        self.assertIsInstance(outputs_batch, Wav2Vec2CTCTokenizerOutput)
        self.assertIsInstance(outputs_list[0], Wav2Vec2CTCTokenizerOutput)

        # transform list to ModelOutput
        field_names = outputs_list[0]
        outputs_batch_2 = Wav2Vec2CTCTokenizerOutput(
            {name: [single[name] for single in outputs_list] for name in field_names}
        )

        self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])

        if "char_offsets" in outputs_batch:
            recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])

        if "word_offsets" in outputs_batch:
            recursive_check(outputs_batch["word_offsets"], outputs_batch_2["word_offsets"])

    # fmt: off
    sample_ids = [
        [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
        [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
    ]
    # fmt: on

    # We assume that `decode` works as expected. All we will check now is
    # the output type is correct and the output is identical to `decode`
    offset_flag_sets = (
        {"output_char_offsets": True},  # char
        {"output_word_offsets": True},  # word
        {"output_char_offsets": True, "output_word_offsets": True},  # both
    )
    for offset_flags in offset_flag_sets:
        batched_outputs = tokenizer.batch_decode(sample_ids, **offset_flags)
        single_outputs = [tokenizer.decode(ids, **offset_flags) for ids in sample_ids]
        check_list_tuples_equal(batched_outputs, single_outputs)
|
||||||
|
|
||||||
|
def test_offsets_integration(self):
    """Compare char and word time stamps derived from real model predictions with recorded reference values."""
    tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
    # pred_ids correspond to the following code
    # ```
    #        from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
    #        from datasets import load_dataset
    #        import datasets
    #        import torch
    #        model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    #        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
    #
    #        ds = load_dataset("common_voice", "en", split="train", streaming=True)
    #        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
    #        ds_iter = iter(ds)
    #        sample = next(ds_iter)
    #
    #        input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
    #        logits = model(input_values).logits
    #        pred_ids = torch.argmax(logits, axis=-1).cpu().tolist()
    # ```
    # fmt: off
    pred_ids = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 11, 0, 0, 0, 22, 0, 0, 4, 4, 4, 14, 0, 0, 0, 0, 0, 8, 8, 0, 5, 5, 0, 12, 0, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 10, 0, 0, 0, 15, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 7, 0, 9, 0, 0, 14, 0, 0, 0, 13, 0, 7, 0, 0, 4, 4, 0, 15, 8, 8, 0, 0, 8, 0, 26, 0, 0, 4, 4, 0, 0, 15, 0, 0, 0, 0, 0, 0, 10, 0, 26, 5, 5, 0, 4, 4, 0, 0, 12, 11, 0, 0, 5, 4, 4, 4, 0, 18, 0, 0, 0, 7, 9, 9, 0, 6, 0, 12, 12, 4, 4, 0, 6, 0, 0, 8, 0, 4, 4, 4, 0, 19, 0, 0, 8, 9, 9, 0, 0, 0, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 0, 17, 5, 5, 5, 0, 4, 4, 4, 0, 0, 29, 29, 0, 0, 0, 0, 8, 11, 0, 9, 9, 0, 0, 0, 4, 4, 0, 12, 12, 0, 0, 0, 9, 0, 0, 0, 0, 0, 8, 18, 0, 0, 0, 4, 4, 0, 0, 8, 9, 0, 4, 4, 0, 6, 11, 5, 0, 4, 4, 0, 13, 13, 0, 0, 0, 10, 0, 0, 25, 0, 0, 6, 0, 4, 4, 0, 0, 0, 0, 7, 0, 0, 23, 0, 0, 4, 4, 0, 0, 0, 6, 11, 0, 5, 4, 4, 18, 0, 0, 0, 0, 0, 0, 7, 15, 0, 0, 0, 15, 15, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

    # wav2vec2-base downsamples input audio by a factor of 320
    # sampling rate for wav2vec2-base is 16_000
    time_offset_wav2vec2_base = 320 / 16_000

    expected_char_time_stamps_text = ['W', 'H', 'Y', ' ', 'D', 'O', 'E', 'S', ' ', 'M', 'I', 'L', 'I', 'S', 'A', 'N', 'D', 'R', 'A', ' ', 'L', 'O', 'O', 'K', ' ', 'L', 'I', 'K', 'E', ' ', 'S', 'H', 'E', ' ', 'W', 'A', 'N', 'T', 'S', ' ', 'T', 'O', ' ', 'C', 'O', 'N', 'S', 'U', 'M', 'E', ' ', 'J', 'O', 'H', 'N', ' ', 'S', 'N', 'O', 'W', ' ', 'O', 'N', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'T', ' ', 'A', 'P', ' ', 'T', 'H', 'E', ' ', 'W', 'A', 'L', 'L', ' ']
    expected_char_time_stamps_start = [1.42, 1.44, 1.52, 1.58, 1.64, 1.76, 1.82, 1.88, 1.92, 2.26, 2.32, 2.4, 2.46, 2.54, 2.66, 2.7, 2.76, 2.84, 2.88, 2.94, 3.0, 3.02, 3.1, 3.14, 3.2, 3.28, 3.42, 3.46, 3.48, 3.54, 3.62, 3.64, 3.7, 3.72, 3.8, 3.88, 3.9, 3.96, 4.0, 4.04, 4.1, 4.16, 4.2, 4.28, 4.34, 4.36, 4.48, 4.66, 4.74, 4.76, 4.84, 4.94, 5.06, 5.08, 5.12, 5.22, 5.28, 5.38, 5.5, 5.52, 5.6, 5.68, 5.7, 5.74, 5.8, 5.82, 5.84, 5.88, 5.94, 6.04, 6.1, 6.16, 6.2, 6.32, 6.38, 6.44, 6.54, 6.56, 6.6, 6.62, 6.66, 6.8, 6.82, 6.9, 6.96]
    expected_char_time_stamps_end = [1.44, 1.46, 1.54, 1.64, 1.66, 1.8, 1.86, 1.9, 2.06, 2.28, 2.34, 2.42, 2.48, 2.56, 2.68, 2.72, 2.78, 2.86, 2.9, 2.98, 3.02, 3.06, 3.12, 3.16, 3.24, 3.3, 3.44, 3.48, 3.52, 3.58, 3.64, 3.66, 3.72, 3.78, 3.82, 3.9, 3.94, 3.98, 4.04, 4.08, 4.12, 4.18, 4.26, 4.3, 4.36, 4.4, 4.52, 4.7, 4.76, 4.82, 4.9, 4.98, 5.08, 5.1, 5.16, 5.26, 5.32, 5.4, 5.52, 5.54, 5.64, 5.7, 5.72, 5.78, 5.82, 5.84, 5.86, 5.92, 5.98, 6.06, 6.12, 6.18, 6.24, 6.34, 6.4, 6.48, 6.56, 6.58, 6.62, 6.66, 6.68, 6.82, 6.84, 6.94, 7.02]

    expected_word_time_stamps_text = ['WHY', 'DOES', 'MILISANDRA', 'LOOK', 'LIKE', 'SHE', 'WANTS', 'TO', 'CONSUME', 'JOHN', 'SNOW', 'ON', 'THE', 'RIVT', 'AP', 'THE', 'WALL']
    expected_word_time_stamps_start = [1.42, 1.64, 2.26, 3.0, 3.28, 3.62, 3.8, 4.1, 4.28, 4.94, 5.28, 5.68, 5.8, 5.94, 6.32, 6.54, 6.66]
    expected_word_time_stamps_end = [1.54, 1.9, 2.9, 3.16, 3.52, 3.72, 4.04, 4.18, 4.82, 5.16, 5.54, 5.72, 5.86, 6.18, 6.4, 6.62, 6.94]
    # fmt: on

    output = tokenizer.batch_decode(pred_ids, output_char_offsets=True, output_word_offsets=True)

    # there is a single sequence in the batch, so only its offsets are inspected
    first_char_offsets = output["char_offsets"][0]
    char_offsets_text = self.get_from_offsets(first_char_offsets, "char")
    char_offsets_start = self.get_from_offsets(first_char_offsets, "start_offset")
    char_offsets_end = self.get_from_offsets(first_char_offsets, "end_offset")

    first_word_offsets = output["word_offsets"][0]
    word_offsets_text = self.get_from_offsets(first_word_offsets, "word")
    word_offsets_start = self.get_from_offsets(first_word_offsets, "start_offset")
    word_offsets_end = self.get_from_offsets(first_word_offsets, "end_offset")

    def to_time_stamps(frame_offsets):
        # let's transform offsets to time stamps in seconds
        return [round(frame_offset * time_offset_wav2vec2_base, 2) for frame_offset in frame_offsets]

    char_time_stamps_start = to_time_stamps(char_offsets_start)
    char_time_stamps_end = to_time_stamps(char_offsets_end)

    word_time_stamps_start = to_time_stamps(word_offsets_start)
    word_time_stamps_end = to_time_stamps(word_offsets_end)

    # NOTE: you can verify the above results by checking out the dataset viewer
    # on https://huggingface.co/datasets/common_voice/viewer/en/train and
    # downloading / playing the sample `common_voice_en_100038.mp3`. As
    # you can hear the time-stamps match more or less

    self.assertListEqual(expected_char_time_stamps_text, char_offsets_text)
    self.assertListEqual(expected_char_time_stamps_start, char_time_stamps_start)
    self.assertListEqual(expected_char_time_stamps_end, char_time_stamps_end)

    self.assertListEqual(expected_word_time_stamps_text, word_offsets_text)
    self.assertListEqual(expected_word_time_stamps_start, word_time_stamps_start)
    self.assertListEqual(expected_word_time_stamps_end, word_time_stamps_end)
|
||||||
|
|
||||||
def test_pretrained_model_lists(self):
|
def test_pretrained_model_lists(self):
|
||||||
# Wav2Vec2Model has no max model length => no testing
|
# Wav2Vec2Model has no max model length => no testing
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
from transformers import Wav2Vec2PhonemeCTCTokenizer
|
from transformers import Wav2Vec2PhonemeCTCTokenizer
|
||||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||||
|
from transformers.models.wav2vec2_phoneme.tokenization_wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizerOutput
|
||||||
from transformers.testing_utils import require_phonemizer
|
from transformers.testing_utils import require_phonemizer
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
@@ -248,23 +249,94 @@ class Wav2Vec2PhonemeCTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||||
self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? $$$", "j ð s j ð s oːɹ $$$"])
|
self.assertEqual(batch_tokens, ["k s ɾ ɾ l ɭʲ!?!? $$$", "j ð s j ð s oːɹ $$$"])
|
||||||
|
|
||||||
# overwrite common test
|
@staticmethod
def get_from_offsets(offsets, key):
    """Return the value stored under `key` in every offset dict of `offsets`, in order."""
    return [entry[key] for entry in offsets]
|
||||||
|
|
||||||
|
def test_offsets(self):
    """Check the text and character offsets returned by `decode(..., output_char_offsets=True)`.

    A hand-built id sequence containing repeated ids, pad tokens and word
    delimiters is decoded with `filter_word_delimiter_token=False`; the test
    then verifies the output type, its keys, the order of the decoded
    characters and the start/end offset of every character.
    """
    tokenizer = self.get_tokenizer(word_delimiter_token="|")
    tokenizer.add_tokens("|")

    # fmt: off
    # ksssɾɾ<pad>ɾɾ|<pad>ɾlll|ɭʲ -> k s ɾ ɾ | ɾ l | ɭʲ
    sample_ids = [11, 5, 5, 5, 15, 15, tokenizer.pad_token_id, 15, 15, tokenizer.word_delimiter_token_id, tokenizer.pad_token_id, 15, 8, 8, 8, tokenizer.word_delimiter_token_id, 98]
    # fmt: on

    outputs = tokenizer.decode(sample_ids, output_char_offsets=True, filter_word_delimiter_token=False)

    # check Wav2Vec2PhonemeCTCTokenizerOutput keys for char
    # (assertEqual, not assertTrue: assertTrue would treat `2` as the failure
    # message and never compare the key count)
    self.assertEqual(len(outputs.keys()), 2)
    self.assertIn("text", outputs)
    self.assertIn("char_offsets", outputs)
    self.assertIsInstance(outputs, Wav2Vec2PhonemeCTCTokenizerOutput)

    char_offsets = outputs["char_offsets"]
    decoded_chars = self.get_from_offsets(char_offsets, "char")

    # check that order of chars is correct and identical for both outputs
    self.assertEqual(" ".join(decoded_chars), outputs.text)
    self.assertListEqual(decoded_chars, ["k", "s", "ɾ", "ɾ", "|", "ɾ", "l", "|", "ɭʲ"])

    # check that offsets are actually correct for char
    # 0-1 is 11, 1-4 is 5, 4-6 is first 15, 6-7 is <pad> (thus not shown), 7-9 is second 15, 9-10 is word_delimiter_token,
    # 10-11 is <pad> (thus not shown), 11-12 is third 15, 12-15 is 8, 15-16 is word_delimiter_token, 16-17 is 98
    self.assertListEqual(
        self.get_from_offsets(char_offsets, "start_offset"), [0, 1, 4, 7, 9, 11, 12, 15, 16]
    )
    self.assertListEqual(
        self.get_from_offsets(char_offsets, "end_offset"), [1, 4, 6, 9, 10, 12, 15, 16, 17]
    )
|
||||||
|
|
||||||
|
def test_offsets_batch(self):
    """Check that `batch_decode` with char offsets agrees with per-sample `decode`.

    `decode` itself is assumed to be correct here; this test only verifies
    that the batched output has the expected type and that its `text` and
    `char_offsets` match what `decode` returns for each sample individually.
    """
    # NOTE(review): unlike `test_offsets`, "|" is not registered through
    # `add_tokens` here -- confirm `word_delimiter_token_id` resolves to the
    # intended id for this tokenizer.
    tokenizer = self.get_tokenizer(word_delimiter_token="|")

    def check_list_tuples_equal(outputs_batch, outputs_list):
        """Assert that a batched output equals a list of single-sample outputs."""
        self.assertIsInstance(outputs_batch, Wav2Vec2PhonemeCTCTokenizerOutput)
        self.assertIsInstance(outputs_list[0], Wav2Vec2PhonemeCTCTokenizerOutput)

        # transform list to ModelOutput
        outputs_batch_2 = Wav2Vec2PhonemeCTCTokenizerOutput(
            {k: [d[k] for d in outputs_list] for k in outputs_list[0]}
        )

        self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])

        def recursive_check(list_or_dict_1, list_or_dict_2):
            # Descend into nested lists first so that a mismatch is reported
            # on the innermost differing element, then compare this level.
            if isinstance(list_or_dict_1, list):
                for l1, l2 in zip(list_or_dict_1, list_or_dict_2):
                    recursive_check(l1, l2)
            self.assertEqual(list_or_dict_1, list_or_dict_2)

        if "char_offsets" in outputs_batch:
            recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])

    # fmt: off
    sample_ids = [
        [11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
        [24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
    ]
    # fmt: on

    # We assume that `decode` works as expected. All we will check now is
    # the output type is correct and the output is identical to `decode`

    # char
    outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
    outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
    check_list_tuples_equal(outputs_char_batch, outputs_char)
|
||||||
|
|
||||||
|
@unittest.skip("Wav2Vec2PhonemeTokenizer always lower cases letters to correctly map to phonemes")
def test_added_tokens_do_lower_case(self):
    # Unconditionally skipped; the decorator message carries the reason.
    # Presumably shadows the same-named common tokenizer test -- confirm
    # against TokenizerTesterMixin.
    pass
|
||||||
|
|
||||||
# overwrite common test
@unittest.skip("Wav2Vec2PhonemeTokenizer always puts spaces between phonemes")
def test_encode_decode_with_spaces(self):
    # Unconditionally skipped; the decorator message carries the reason.
    pass
|
||||||
|
|
||||||
# overwrite common test
@unittest.skip("encodes to text to ids, but decodes ids to phonemes -> not possible to have internal consistency")
def test_internal_consistency(self):
    # Unconditionally skipped: text is encoded to ids, but ids are decoded to
    # phonemes, so an encode/decode round trip cannot be consistent.
    pass
|
||||||
|
|
||||||
|
@unittest.skip("Wav2Vec2PhonemeModel has no max model length => no testing")
def test_pretrained_model_lists(self):
    # Unconditionally skipped; the decorator message carries the reason.
    pass
|
||||||
|
|
||||||
# overwrite common
|
# overwrite common
|
||||||
|
|||||||
Reference in New Issue
Block a user