[Test refactor 1/5] Per-folder tests reorganization (#15725)
* Per-folder tests reorganization Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
0
tests/wav2vec2/__init__.py
Normal file
0
tests/wav2vec2/__init__.py
Normal file
225
tests/wav2vec2/test_feature_extraction_wav2vec2.py
Normal file
225
tests/wav2vec2/test_feature_extraction_wav2vec2.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
|
||||
from ..test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
values = []
|
||||
for batch_idx in range(shape[0]):
|
||||
values.append([])
|
||||
for _ in range(shape[1]):
|
||||
values[-1].append(rng.random() * scale)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
min_seq_length=400,
|
||||
max_seq_length=2000,
|
||||
feature_size=1,
|
||||
padding_value=0.0,
|
||||
sampling_rate=16000,
|
||||
return_attention_mask=True,
|
||||
do_normalize=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.min_seq_length = min_seq_length
|
||||
self.max_seq_length = max_seq_length
|
||||
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
|
||||
self.feature_size = feature_size
|
||||
self.padding_value = padding_value
|
||||
self.sampling_rate = sampling_rate
|
||||
self.return_attention_mask = return_attention_mask
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
"feature_size": self.feature_size,
|
||||
"padding_value": self.padding_value,
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"return_attention_mask": self.return_attention_mask,
|
||||
"do_normalize": self.do_normalize,
|
||||
}
|
||||
|
||||
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
|
||||
def _flatten(list_of_lists):
|
||||
return list(itertools.chain(*list_of_lists))
|
||||
|
||||
if equal_length:
|
||||
speech_inputs = floats_list((self.batch_size, self.max_seq_length))
|
||||
else:
|
||||
# make sure that inputs increase in size
|
||||
speech_inputs = [
|
||||
_flatten(floats_list((x, self.feature_size)))
|
||||
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
|
||||
]
|
||||
|
||||
if numpify:
|
||||
speech_inputs = [np.asarray(x) for x in speech_inputs]
|
||||
|
||||
return speech_inputs
|
||||
|
||||
|
||||
class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = Wav2Vec2FeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = Wav2Vec2FeatureExtractionTester(self)
|
||||
|
||||
def _check_zero_mean_unit_variance(self, input_vector):
|
||||
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
|
||||
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
# create three inputs of length 800, 1000, and 1200
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||
|
||||
# Test not batched input
|
||||
encoded_sequences_1 = feat_extract(speech_inputs[0], return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_speech_inputs[0], return_tensors="np").input_values
|
||||
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
|
||||
|
||||
# Test batched
|
||||
encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_values
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_np(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 1600, None]
|
||||
for max_length, padding in zip(max_lengths, paddings):
|
||||
processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
|
||||
input_values = processed.input_values
|
||||
|
||||
self._check_zero_mean_unit_variance(input_values[0][:800])
|
||||
self.assertTrue(input_values[0][800:].sum() < 1e-6)
|
||||
self._check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
self.assertTrue(input_values[0][1000:].sum() < 1e-6)
|
||||
self._check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
lengths = range(800, 1400, 200)
|
||||
speech_inputs = [floats_list((1, x))[0] for x in lengths]
|
||||
|
||||
paddings = ["longest", "max_length", "do_not_pad"]
|
||||
max_lengths = [None, 1600, None]
|
||||
|
||||
for max_length, padding in zip(max_lengths, paddings):
|
||||
processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
|
||||
input_values = processed.input_values
|
||||
|
||||
self._check_zero_mean_unit_variance(input_values[0][:800])
|
||||
self._check_zero_mean_unit_variance(input_values[1][:1000])
|
||||
self._check_zero_mean_unit_variance(input_values[2][:1200])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_max_length(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = feat_extract(
|
||||
speech_inputs, truncation=True, max_length=1000, padding="max_length", return_tensors="np"
|
||||
)
|
||||
input_values = processed.input_values
|
||||
|
||||
self._check_zero_mean_unit_variance(input_values[0, :800])
|
||||
self._check_zero_mean_unit_variance(input_values[1])
|
||||
self._check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = feat_extract(
|
||||
speech_inputs, truncation=True, max_length=1000, padding="longest", return_tensors="np"
|
||||
)
|
||||
input_values = processed.input_values
|
||||
|
||||
self._check_zero_mean_unit_variance(input_values[0, :800])
|
||||
self._check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
self._check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
# make sure that if max_length < longest -> then pad to max_length
|
||||
self.assertTrue(input_values.shape == (3, 1000))
|
||||
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = feat_extract(
|
||||
speech_inputs, truncation=True, max_length=2000, padding="longest", return_tensors="np"
|
||||
)
|
||||
input_values = processed.input_values
|
||||
|
||||
self._check_zero_mean_unit_variance(input_values[0, :800])
|
||||
self._check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
self._check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
# make sure that if max_length > longest -> then pad to longest
|
||||
self.assertTrue(input_values.shape == (3, 1200))
|
||||
|
||||
@require_torch
|
||||
def test_double_precision_pad(self):
|
||||
import torch
|
||||
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
np_speech_inputs = np.random.rand(100).astype(np.float64)
|
||||
py_speech_inputs = np_speech_inputs.tolist()
|
||||
|
||||
for inputs in [py_speech_inputs, np_speech_inputs]:
|
||||
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
|
||||
self.assertTrue(np_processed.input_values.dtype == np.float32)
|
||||
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
|
||||
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pretrained_checkpoints_are_set_correctly(self):
|
||||
# this test makes sure that models that are using
|
||||
# group norm don't have their feature extractor return the
|
||||
# attention_mask
|
||||
for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
|
||||
config = Wav2Vec2Config.from_pretrained(model_id)
|
||||
feat_extract = Wav2Vec2FeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
# only "layer" feature extraction norm should make use of
|
||||
# attention_mask
|
||||
self.assertEqual(feat_extract.return_attention_mask, config.feat_extract_norm == "layer")
|
||||
482
tests/wav2vec2/test_modeling_flax_wav2vec2.py
Normal file
482
tests/wav2vec2/test_modeling_flax_wav2vec2.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2Config, is_flax_available
|
||||
from transformers.testing_utils import (
|
||||
is_librosa_available,
|
||||
is_pyctcdecode_available,
|
||||
require_flax,
|
||||
require_librosa,
|
||||
require_pyctcdecode,
|
||||
require_soundfile,
|
||||
slow,
|
||||
)
|
||||
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
|
||||
FlaxWav2Vec2ForCTC,
|
||||
FlaxWav2Vec2ForPreTraining,
|
||||
FlaxWav2Vec2GumbelVectorQuantizer,
|
||||
FlaxWav2Vec2Model,
|
||||
_compute_mask_indices,
|
||||
_sample_negative_indices,
|
||||
)
|
||||
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
|
||||
class FlaxWav2Vec2ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=1024, # speech is longer
|
||||
is_training=False,
|
||||
hidden_size=24,
|
||||
feat_extract_norm="layer",
|
||||
feat_extract_dropout=0.0,
|
||||
feat_extract_activation="gelu",
|
||||
conv_dim=(32, 32, 32),
|
||||
conv_stride=(4, 4, 4),
|
||||
conv_kernel=(8, 8, 8),
|
||||
conv_bias=False,
|
||||
num_conv_pos_embeddings=16,
|
||||
num_conv_pos_embedding_groups=2,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=2,
|
||||
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
|
||||
intermediate_size=20,
|
||||
layer_norm_eps=1e-5,
|
||||
hidden_act="gelu",
|
||||
initializer_range=0.02,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=True,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.feat_extract_norm = feat_extract_norm
|
||||
self.feat_extract_dropout = feat_extract_dropout
|
||||
self.feat_extract_activation = feat_extract_activation
|
||||
self.conv_dim = conv_dim
|
||||
self.conv_stride = conv_stride
|
||||
self.conv_kernel = conv_kernel
|
||||
self.conv_bias = conv_bias
|
||||
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
||||
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.intermediate_size = intermediate_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.scope = scope
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
||||
output_seq_length = (output_seq_length - (kernel - 1)) / stride
|
||||
self.output_seq_length = int(math.ceil(output_seq_length))
|
||||
self.encoder_seq_length = self.output_seq_length
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
config = Wav2Vec2Config(
|
||||
do_stable_layer_norm=self.do_stable_layer_norm,
|
||||
hidden_size=self.hidden_size,
|
||||
feat_extract_norm=self.feat_extract_norm,
|
||||
feat_extract_dropout=self.feat_extract_dropout,
|
||||
feat_extract_activation=self.feat_extract_activation,
|
||||
conv_dim=self.conv_dim,
|
||||
conv_stride=self.conv_stride,
|
||||
conv_kernel=self.conv_kernel,
|
||||
conv_bias=self.conv_bias,
|
||||
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
||||
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
intermediate_size=self.intermediate_size,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
|
||||
return config, input_values, attention_mask
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_values, attention_mask = config_and_inputs
|
||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxWav2Vec2ModelTester(self)
|
||||
|
||||
def test_train(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_values = inputs_dict["input_values"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
|
||||
model = FlaxWav2Vec2ForPreTraining(config)
|
||||
|
||||
features_shape = (
|
||||
input_values.shape[0],
|
||||
model._get_feat_extract_output_lengths(np.array(input_values.shape[1])),
|
||||
)
|
||||
|
||||
batch_size, sequence_length = features_shape[:2]
|
||||
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
|
||||
dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))
|
||||
|
||||
output = model(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
mask_time_indices=mask_time_indices,
|
||||
train=True,
|
||||
dropout_rng=dropout_rng,
|
||||
gumbel_rng=gumbel_rng,
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output.shape == (batch_size, sequence_length, model.config.proj_codevector_dim))
|
||||
|
||||
# overwrite because of `input_values`
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.__call__)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_values", "attention_mask"]
|
||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||
|
||||
# overwrite because of `input_values`
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_values, attention_mask=None, **kwargs):
|
||||
return model(input_values=input_values, attention_mask=attention_mask, **kwargs)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
|
||||
outputs = model(np.ones((1, 1024), dtype="f4"))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
||||
def test_compute_mask_indices(self):
|
||||
batch_size = 4
|
||||
sequence_length = 60
|
||||
mask_prob = 0.5
|
||||
mask_length = 1
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
|
||||
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
def test_compute_mask_indices_attn_mask_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
|
||||
attention_mask[:2, sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
|
||||
|
||||
def test_compute_perplexity(self):
|
||||
probs = np.arange(100).reshape(2, 5, 10) / 100
|
||||
|
||||
ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
|
||||
self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
|
||||
|
||||
# mask half of the input
|
||||
mask = np.ones((2,), dtype=np.bool)
|
||||
mask[0] = 0
|
||||
|
||||
ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
|
||||
self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
|
||||
|
||||
def test_sample_negatives(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
|
||||
|
||||
negative_indices = _sample_negative_indices(features.shape, num_negatives)
|
||||
|
||||
features = features.reshape(-1, hidden_size) # BTC => (BxT)C
|
||||
# take negative vectors from sampled indices
|
||||
sampled_negatives = features[negative_indices.reshape(-1)]
|
||||
negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
|
||||
2, 0, 1, 3
|
||||
)
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not values of vectors
|
||||
# => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
def test_sample_negatives_with_attn_mask(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
|
||||
# second half of last input tensor is padded
|
||||
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8)
|
||||
attention_mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
forbidden_indices = (
|
||||
np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length
|
||||
).tolist()
|
||||
|
||||
features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
|
||||
|
||||
negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask)
|
||||
|
||||
# make sure that no padding tokens are sampled
|
||||
self.assertTrue(all([idx not in negative_indices for idx in forbidden_indices]))
|
||||
|
||||
features = features.reshape(-1, hidden_size) # BTC => (BxT)C
|
||||
# take negative vectors from sampled indices
|
||||
sampled_negatives = features[negative_indices.reshape(-1)]
|
||||
negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
|
||||
2, 0, 1, 3
|
||||
)
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not just slices of vectors
|
||||
# => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
|
||||
@require_flax
|
||||
@require_soundfile
|
||||
@slow
|
||||
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def _load_datasamples(self, num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").filter(
|
||||
lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)]
|
||||
)[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_inference_ctc_robust_batched(self):
|
||||
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values
|
||||
attention_mask = inputs.attention_mask
|
||||
|
||||
logits = model(input_values, attention_mask=attention_mask).logits
|
||||
|
||||
predicted_ids = jnp.argmax(logits, axis=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
||||
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
|
||||
"his instant panic was followed by a small sharp blow high on his chest",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_pretrained(self):
|
||||
model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"facebook/wav2vec2-large-lv60", return_attention_mask=True
|
||||
)
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)
|
||||
|
||||
features_shape = (
|
||||
inputs_dict["input_values"].shape[0],
|
||||
model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
|
||||
)
|
||||
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
model.config.mask_time_prob,
|
||||
model.config.mask_time_length,
|
||||
min_masks=2,
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
inputs_dict.input_values,
|
||||
attention_mask=inputs_dict.attention_mask,
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim = optax.cosine_similarity(
|
||||
outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
|
||||
)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked = cosine_sim[mask_time_indices]
|
||||
|
||||
# ... now compare to randomly initialized model
|
||||
|
||||
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
model_rand = FlaxWav2Vec2ForPreTraining(config)
|
||||
|
||||
outputs_rand = model_rand(
|
||||
inputs_dict.input_values,
|
||||
attention_mask=inputs_dict.attention_mask,
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
|
||||
# compute cosine similarity
|
||||
cosine_sim_rand = optax.cosine_similarity(
|
||||
outputs_rand.projected_states, outputs_rand.projected_quantized_states
|
||||
)
|
||||
|
||||
# retrieve cosine sim of masked features
|
||||
cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]
|
||||
|
||||
# a pretrained wav2vec2 model has learned to predict the quantized latent states
|
||||
# => the cosine similarity between quantized states and predicted states > 0.5
|
||||
# a random wav2vec2 model has not learned to predict the quantized latent states
|
||||
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
|
||||
self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||
|
||||
model = FlaxWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="np").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
transcription = processor.batch_decode(np.array(logits)).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
572
tests/wav2vec2/test_modeling_tf_wav2vec2.py
Normal file
572
tests/wav2vec2/test_modeling_tf_wav2vec2.py
Normal file
@@ -0,0 +1,572 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
|
||||
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
|
||||
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from transformers import Wav2Vec2ProcessorWithLM
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=1024,
|
||||
is_training=False,
|
||||
hidden_size=16,
|
||||
feat_extract_norm="group",
|
||||
feat_extract_dropout=0.0,
|
||||
feat_extract_activation="gelu",
|
||||
conv_dim=(32, 32, 32),
|
||||
conv_stride=(4, 4, 4),
|
||||
conv_kernel=(8, 8, 8),
|
||||
conv_bias=False,
|
||||
num_conv_pos_embeddings=16,
|
||||
num_conv_pos_embedding_groups=2,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=2,
|
||||
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
|
||||
intermediate_size=20,
|
||||
layer_norm_eps=1e-5,
|
||||
hidden_act="gelu",
|
||||
initializer_range=0.02,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=False,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.feat_extract_norm = feat_extract_norm
|
||||
self.feat_extract_dropout = feat_extract_dropout
|
||||
self.feat_extract_activation = feat_extract_activation
|
||||
self.conv_dim = conv_dim
|
||||
self.conv_stride = conv_stride
|
||||
self.conv_kernel = conv_kernel
|
||||
self.conv_bias = conv_bias
|
||||
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
||||
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.intermediate_size = intermediate_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.scope = scope
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
||||
output_seq_length = (output_seq_length - (kernel - 1)) / stride
|
||||
self.output_seq_length = int(math.ceil(output_seq_length))
|
||||
self.encoder_seq_length = self.output_seq_length
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = tf.cast(ids_tensor([self.batch_size, self.seq_length], 32768), tf.float32) / 32768.0
|
||||
attention_mask = tf.ones_like(input_values)
|
||||
|
||||
config = Wav2Vec2Config(
|
||||
hidden_size=self.hidden_size,
|
||||
feat_extract_norm=self.feat_extract_norm,
|
||||
feat_extract_dropout=self.feat_extract_dropout,
|
||||
feat_extract_activation=self.feat_extract_activation,
|
||||
conv_dim=self.conv_dim,
|
||||
conv_stride=self.conv_stride,
|
||||
conv_kernel=self.conv_kernel,
|
||||
conv_bias=self.conv_bias,
|
||||
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
||||
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
intermediate_size=self.intermediate_size,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
do_stable_layer_norm=self.do_stable_layer_norm,
|
||||
)
|
||||
|
||||
return config, input_values, attention_mask
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
model = TFWav2Vec2Model(config)
|
||||
result = model(input_values, attention_mask=attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
config.layerdrop = 0.0
|
||||
model = TFWav2Vec2Model(config)
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = tf.ones_like(input_values)
|
||||
|
||||
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
|
||||
length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)
|
||||
|
||||
# convert values that are over input_lengths to padding
|
||||
input_values = input_values * length_mask
|
||||
attention_mask = attention_mask * length_mask
|
||||
|
||||
batch_outputs = model(input_values, attention_mask=attention_mask, training=False).last_hidden_state
|
||||
|
||||
for i in range(input_values.shape[0]):
|
||||
input_slice = input_values[i : i + 1, : input_lengths[i]]
|
||||
output = model(input_slice, training=False).last_hidden_state
|
||||
|
||||
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
|
||||
self.parent.assertTrue(np.allclose(output, batch_output, atol=1e-3))
|
||||
|
||||
def check_ctc_loss(self, config, input_values, *args):
|
||||
model = TFWav2Vec2ForCTC(config)
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = tf.ones_like(input_values)
|
||||
|
||||
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
|
||||
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
|
||||
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
|
||||
|
||||
length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)
|
||||
|
||||
# convert values that are over input_lengths to padding
|
||||
input_values = input_values * length_mask
|
||||
attention_mask = attention_mask * length_mask
|
||||
|
||||
model.config.ctc_loss_reduction = "sum"
|
||||
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
|
||||
|
||||
model.config.ctc_loss_reduction = "mean"
|
||||
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
|
||||
|
||||
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
|
||||
|
||||
def check_training(self, config, input_values, *args):
|
||||
model = TFWav2Vec2ForCTC(config)
|
||||
|
||||
# freeze feature encoder
|
||||
model.freeze_feature_encoder()
|
||||
|
||||
input_values = input_values[:3]
|
||||
|
||||
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
|
||||
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
|
||||
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
|
||||
|
||||
length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)
|
||||
|
||||
input_values = input_values * length_mask
|
||||
|
||||
pad_size = max(max_length_labels) - labels.shape[1]
|
||||
labels = tf.pad(labels, ((0, 0), (0, pad_size)), constant_values=-100)
|
||||
|
||||
loss = model(input_values, labels=labels, training=True).loss
|
||||
|
||||
self.parent.assertFalse(tf.math.is_inf(loss))
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = TFWav2Vec2ForCTC(config)
|
||||
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
|
||||
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
|
||||
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size + 100)
|
||||
with pytest.raises(ValueError):
|
||||
model(input_values, labels=labels)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFWav2Vec2ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite because input_values != input_ids
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
# overwrite because input_values != input_ids
|
||||
def test_keyword_and_dict_args(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs_dict = model(inputs)
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
input_values = inputs_keywords.pop("input_values", None)
|
||||
outputs_keywords = model(input_values, **inputs_keywords)
|
||||
output_dict = outputs_dict[0].numpy()
|
||||
output_keywords = outputs_keywords[0].numpy()
|
||||
|
||||
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_hidden_states_output(config, inputs_dict, model_class):
|
||||
model = model_class(config)
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.output_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
def test_ctc_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFWav2Vec2ModelTester(
|
||||
self,
|
||||
conv_stride=(3, 3, 3),
|
||||
feat_extract_norm="layer",
|
||||
do_stable_layer_norm=True,
|
||||
scope="robust",
|
||||
)
|
||||
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
|
||||
|
||||
# overwrite because input_values != input_ids
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
# overwrite because input_values != input_ids
|
||||
def test_keyword_and_dict_args(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs_dict = model(inputs)
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
input_values = inputs_keywords.pop("input_values", None)
|
||||
outputs_keywords = model(input_values, **inputs_keywords)
|
||||
output_dict = outputs_dict[0].numpy()
|
||||
output_keywords = outputs_keywords[0].numpy()
|
||||
|
||||
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_hidden_states_output(config, inputs_dict, model_class):
|
||||
model = model_class(config)
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.output_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
def test_batched_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
|
||||
|
||||
def test_ctc_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 cannot resize token embeddings
|
||||
# since it has no tokens embeddings
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
# and thus the `get_input_embeddings` fn
|
||||
# is not implemented
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2UtilsTest(unittest.TestCase):
|
||||
def test_compute_mask_indices(self):
|
||||
batch_size = 4
|
||||
sequence_length = 60
|
||||
mask_prob = 0.5
|
||||
mask_length = 1
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
|
||||
self.assertListEqual(
|
||||
tf.reduce_sum(mask, -1).numpy().tolist(), [mask_prob * sequence_length for _ in range(batch_size)]
|
||||
)
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
|
||||
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
|
||||
for batch_sum in tf.reduce_sum(mask, -1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def _load_datasamples(self, num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").filter(
|
||||
lambda x: x["id"] in [f"1272-141231-000{i}" for i in range(num_samples)]
|
||||
)[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_inference_ctc_normal(self):
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
input_values = processor(input_speech, return_tensors="tf", sampling_rate=16000).input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = tf.argmax(logits, axis=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_ctc_normal_batched(self):
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
input_values = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000).input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = tf.argmax(logits, axis=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_ctc_robust_batched(self):
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="tf", padding=True, sampling_rate=16000)
|
||||
|
||||
input_values = inputs.input_values
|
||||
attention_mask = inputs.attention_mask
|
||||
|
||||
logits = model(input_values, attention_mask=attention_mask).logits
|
||||
|
||||
predicted_ids = tf.argmax(logits, axis=-1)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
||||
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
|
||||
"his instant panic was followed by a small sharp blow high on his chest",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm(self):
|
||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||
sample = librosa.load(file_path, sr=16_000)[0]
|
||||
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(sample, return_tensors="tf").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
1554
tests/wav2vec2/test_modeling_wav2vec2.py
Normal file
1554
tests/wav2vec2/test_modeling_wav2vec2.py
Normal file
File diff suppressed because it is too large
Load Diff
140
tests/wav2vec2/test_processor_wav2vec2.py
Normal file
140
tests/wav2vec2/test_processor_wav2vec2.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
|
||||
from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
|
||||
class Wav2Vec2ProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
|
||||
self.add_kwargs_tokens_map = {
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
}
|
||||
feature_extractor_map = {
|
||||
"feature_size": 1,
|
||||
"padding_value": 0.0,
|
||||
"sampling_rate": 16000,
|
||||
"return_attention_mask": False,
|
||||
"do_normalize": True,
|
||||
}
|
||||
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
|
||||
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(feature_extractor_map) + "\n")
|
||||
|
||||
def get_tokenizer(self, **kwargs_init):
|
||||
kwargs = self.add_kwargs_tokens_map.copy()
|
||||
kwargs.update(kwargs_init)
|
||||
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_feature_extractor(self, **kwargs):
|
||||
return Wav2Vec2FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_save_load_pretrained_default(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = Wav2Vec2Processor.from_pretrained(self.tmpdirname)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = Wav2Vec2Processor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
|
||||
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
|
||||
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
|
||||
)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
raw_speech = floats_list((3, 1000))
|
||||
|
||||
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
|
||||
input_processor = processor(raw_speech, return_tensors="np")
|
||||
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_tokenizer(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||
|
||||
decoded_processor = processor.batch_decode(predicted_ids)
|
||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||
|
||||
self.assertListEqual(decoded_tok, decoded_processor)
|
||||
719
tests/wav2vec2/test_tokenization_wav2vec2.py
Normal file
719
tests/wav2vec2/test_tokenization_wav2vec2.py
Normal file
@@ -0,0 +1,719 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for the Wav2Vec2 tokenizer."""
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import (
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2CTCTokenizer,
|
||||
Wav2Vec2Tokenizer,
|
||||
)
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2CTCTokenizerOutput
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
|
||||
from ..test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
values = []
|
||||
for batch_idx in range(shape[0]):
|
||||
values.append([])
|
||||
for _ in range(shape[1]):
|
||||
values[-1].append(rng.random() * scale)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class Wav2Vec2TokenizerTest(unittest.TestCase):
|
||||
tokenizer_class = Wav2Vec2Tokenizer
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
|
||||
self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return Wav2Vec2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
# TODO(PVP) - change to facebook
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
tokens = tokenizer.decode(sample_ids[0])
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
self.assertEqual(tokens, batch_tokens[0])
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
|
||||
|
||||
def test_tokenizer_decode_special(self):
|
||||
# TODO(PVP) - change to facebook
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
sample_ids_2 = [
|
||||
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
77,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
],
|
||||
]
|
||||
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
||||
self.assertEqual(batch_tokens, batch_tokens_2)
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
|
||||
|
||||
def test_tokenizer_decode_added_tokens(self):
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
tokenizer.add_tokens(["!", "?"])
|
||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||
|
||||
sample_ids = [
|
||||
[
|
||||
11,
|
||||
5,
|
||||
15,
|
||||
tokenizer.pad_token_id,
|
||||
15,
|
||||
8,
|
||||
98,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
34,
|
||||
],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||
]
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
tokenizer = self.get_tokenizer()
|
||||
# create three inputs of length 800, 1000, and 1200
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||
|
||||
# Test not batched input
|
||||
encoded_sequences_1 = tokenizer(speech_inputs[0], return_tensors="np").input_values
|
||||
encoded_sequences_2 = tokenizer(np_speech_inputs[0], return_tensors="np").input_values
|
||||
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
|
||||
|
||||
# Test batched
|
||||
encoded_sequences_1 = tokenizer(speech_inputs, return_tensors="np").input_values
|
||||
encoded_sequences_2 = tokenizer(np_speech_inputs, return_tensors="np").input_values
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_padding(self, max_length=50):
|
||||
def _input_values_have_equal_length(input_values):
|
||||
length = len(input_values[0])
|
||||
for input_values_slice in input_values[1:]:
|
||||
if len(input_values_slice) != length:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _input_values_are_equal(input_values_1, input_values_2):
|
||||
if len(input_values_1) != len(input_values_2):
|
||||
return False
|
||||
|
||||
for input_values_slice_1, input_values_slice_2 in zip(input_values_1, input_values_2):
|
||||
if not np.allclose(np.asarray(input_values_slice_1), np.asarray(input_values_slice_2), atol=1e-3):
|
||||
return False
|
||||
return True
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
input_values_1 = tokenizer(speech_inputs).input_values
|
||||
input_values_2 = tokenizer(speech_inputs, padding="longest").input_values
|
||||
input_values_3 = tokenizer(speech_inputs, padding="longest", max_length=1600).input_values
|
||||
|
||||
self.assertFalse(_input_values_have_equal_length(input_values_1))
|
||||
self.assertTrue(_input_values_have_equal_length(input_values_2))
|
||||
self.assertTrue(_input_values_have_equal_length(input_values_3))
|
||||
self.assertTrue(_input_values_are_equal(input_values_2, input_values_3))
|
||||
self.assertTrue(len(input_values_1[0]) == 800)
|
||||
self.assertTrue(len(input_values_2[0]) == 1200)
|
||||
# padding should be 0.0
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_2[0])[800:])) < 1e-3)
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_2[1])[1000:])) < 1e-3)
|
||||
|
||||
input_values_4 = tokenizer(speech_inputs, padding="max_length").input_values
|
||||
input_values_5 = tokenizer(speech_inputs, padding="max_length", max_length=1600).input_values
|
||||
|
||||
self.assertTrue(_input_values_are_equal(input_values_1, input_values_4))
|
||||
self.assertTrue(input_values_5.shape, (3, 1600))
|
||||
# padding should be 0.0
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_5[0])[800:1200])) < 1e-3)
|
||||
|
||||
input_values_6 = tokenizer(speech_inputs, pad_to_multiple_of=500).input_values
|
||||
input_values_7 = tokenizer(speech_inputs, padding="longest", pad_to_multiple_of=500).input_values
|
||||
input_values_8 = tokenizer(
|
||||
speech_inputs, padding="max_length", pad_to_multiple_of=500, max_length=2400
|
||||
).input_values
|
||||
|
||||
self.assertTrue(_input_values_are_equal(input_values_1, input_values_6))
|
||||
self.assertTrue(input_values_7.shape, (3, 1500))
|
||||
self.assertTrue(input_values_8.shape, (3, 2500))
|
||||
# padding should be 0.0
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_7[0])[800:])) < 1e-3)
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_7[1])[1000:])) < 1e-3)
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_7[2])[1200:])) < 1e-3)
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_8[0])[800:])) < 1e-3)
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_8[1])[1000:])) < 1e-3)
|
||||
self.assertTrue(abs(sum(np.asarray(input_values_8[2])[1200:])) < 1e-3)
|
||||
|
||||
def test_save_pretrained(self):
|
||||
pretrained_name = list(self.tokenizer_class.pretrained_vocab_files_map["vocab_file"].keys())[0]
|
||||
tokenizer = self.tokenizer_class.from_pretrained(pretrained_name)
|
||||
tmpdirname2 = tempfile.mkdtemp()
|
||||
|
||||
tokenizer_files = tokenizer.save_pretrained(tmpdirname2)
|
||||
self.assertSequenceEqual(
|
||||
sorted(tuple(VOCAB_FILES_NAMES.values()) + ("special_tokens_map.json", "added_tokens.json")),
|
||||
sorted(tuple(x.split(os.path.sep)[-1] for x in tokenizer_files)),
|
||||
)
|
||||
|
||||
# Checks everything loads correctly in the same way
|
||||
tokenizer_p = self.tokenizer_class.from_pretrained(tmpdirname2)
|
||||
|
||||
# Check special tokens are set accordingly on Rust and Python
|
||||
for key in tokenizer.special_tokens_map:
|
||||
self.assertTrue(key in tokenizer_p.special_tokens_map)
|
||||
|
||||
shutil.rmtree(tmpdirname2)
|
||||
|
||||
def test_get_vocab(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
self.assertIsInstance(vocab_dict, dict)
|
||||
self.assertGreaterEqual(len(tokenizer), len(vocab_dict))
|
||||
|
||||
vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
|
||||
self.assertEqual(len(vocab), len(tokenizer))
|
||||
|
||||
tokenizer.add_tokens(["asdfasdfasdfasdf"])
|
||||
vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
|
||||
self.assertEqual(len(vocab), len(tokenizer))
|
||||
|
||||
def test_save_and_load_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
# Isolate this from the other tests because we save additional tokens/etc
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
sample_ids = [0, 1, 4, 8, 9, 0, 12]
|
||||
before_tokens = tokenizer.decode(sample_ids)
|
||||
before_vocab = tokenizer.get_vocab()
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
|
||||
after_tokens = after_tokenizer.decode(sample_ids)
|
||||
after_vocab = after_tokenizer.get_vocab()
|
||||
|
||||
self.assertEqual(before_tokens, after_tokens)
|
||||
self.assertDictEqual(before_vocab, after_vocab)
|
||||
|
||||
shutil.rmtree(tmpdirname)
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Isolate this from the other tests because we save additional tokens/etc
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
before_len = len(tokenizer)
|
||||
sample_ids = [0, 1, 4, 8, 9, 0, 12, before_len, before_len + 1, before_len + 2]
|
||||
tokenizer.add_tokens(["?", "!"])
|
||||
additional_special_tokens = tokenizer.additional_special_tokens
|
||||
additional_special_tokens.append("&")
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
before_tokens = tokenizer.decode(sample_ids)
|
||||
before_vocab = tokenizer.get_vocab()
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
|
||||
after_tokens = after_tokenizer.decode(sample_ids)
|
||||
after_vocab = after_tokenizer.get_vocab()
|
||||
|
||||
self.assertEqual(before_tokens, after_tokens)
|
||||
self.assertDictEqual(before_vocab, after_vocab)
|
||||
|
||||
self.assertTrue(len(tokenizer), before_len + 3)
|
||||
self.assertTrue(len(tokenizer), len(after_tokenizer))
|
||||
shutil.rmtree(tmpdirname)
|
||||
|
||||
def test_tokenizer_slow_store_full_signature(self):
|
||||
signature = inspect.signature(self.tokenizer_class.__init__)
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
for parameter_name, parameter in signature.parameters.items():
|
||||
if parameter.default != inspect.Parameter.empty:
|
||||
self.assertIn(parameter_name, tokenizer.init_kwargs)
|
||||
|
||||
def test_zero_mean_unit_variance_normalization(self):
|
||||
tokenizer = self.get_tokenizer(do_normalize=True)
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = tokenizer(speech_inputs, padding="longest")
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
||||
_check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
_check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
def test_return_attention_mask(self):
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
# default case -> no attention_mask is returned
|
||||
tokenizer = self.get_tokenizer()
|
||||
processed = tokenizer(speech_inputs)
|
||||
self.assertNotIn("attention_mask", processed)
|
||||
|
||||
# wav2vec2-lv60 -> return attention_mask
|
||||
tokenizer = self.get_tokenizer(return_attention_mask=True)
|
||||
processed = tokenizer(speech_inputs, padding="longest")
|
||||
|
||||
self.assertIn("attention_mask", processed)
|
||||
self.assertListEqual(list(processed.attention_mask.shape), list(processed.input_values.shape))
|
||||
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pretrained_checkpoints_are_set_correctly(self):
|
||||
# this test makes sure that models that are using
|
||||
# group norm don't have their tokenizer return the
|
||||
# attention_mask
|
||||
for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
|
||||
config = Wav2Vec2Config.from_pretrained(model_id)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)
|
||||
|
||||
# only "layer" feature extraction norm should make use of
|
||||
# attention_mask
|
||||
self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
|
||||
|
||||
|
||||
class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = Wav2Vec2CTCTokenizer
|
||||
test_rust_tokenizer = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
|
||||
self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def test_tokenizer_add_token_chars(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# check adding a single token
|
||||
tokenizer.add_tokens("x")
|
||||
token_ids = tokenizer("C x A").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 32, 4, 7])
|
||||
|
||||
tokenizer.add_tokens(["a", "b", "c"])
|
||||
token_ids = tokenizer("C a A c").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35])
|
||||
|
||||
tokenizer.add_tokens(["a", "b", "c"])
|
||||
token_ids = tokenizer("CaA c").input_ids
|
||||
self.assertEqual(token_ids, [19, 33, 7, 4, 35])
|
||||
|
||||
def test_tokenizer_add_token_words(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# check adding a single token
|
||||
tokenizer.add_tokens("xxx")
|
||||
token_ids = tokenizer("C xxx A B").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 32, 4, 7, 4, 24])
|
||||
|
||||
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
|
||||
token_ids = tokenizer("C aaa A ccc B B").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24])
|
||||
|
||||
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
|
||||
token_ids = tokenizer("CaaaA ccc B B").input_ids
|
||||
self.assertEqual(token_ids, [19, 33, 7, 4, 35, 4, 24, 4, 24])
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
tokens = tokenizer.decode(sample_ids[0])
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
self.assertEqual(tokens, batch_tokens[0])
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
|
||||
|
||||
def test_tokenizer_decode_special(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
sample_ids_2 = [
|
||||
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.pad_token_id, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.word_delimiter_token_id],
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
||||
self.assertEqual(batch_tokens, batch_tokens_2)
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
|
||||
|
||||
def test_tokenizer_decode_added_tokens(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
tokenizer.add_tokens(["!", "?"])
|
||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||
]
|
||||
# fmt: on
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
|
||||
def test_special_characters_in_vocab(self):
|
||||
sent = "ʈʰ æ æ̃ ˧ kʰ"
|
||||
|
||||
vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})}
|
||||
vocab_file = os.path.join(self.tmpdirname, "vocab_special.json")
|
||||
|
||||
with open(vocab_file, "w") as f:
|
||||
json.dump(vocab_dict, f)
|
||||
|
||||
tokenizer = Wav2Vec2CTCTokenizer(vocab_file)
|
||||
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
@staticmethod
|
||||
def get_from_offsets(offsets, key):
|
||||
retrieved_list = [d[key] for d in offsets]
|
||||
return retrieved_list
|
||||
|
||||
def test_offsets(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# fmt: off
|
||||
# HEEEEE||LLL<pad>LO<unk> => HE LLO<unk>
|
||||
# 1H + 5E + 2| + 3L + 1<pad> + 1L + 1O + 1<unk>
|
||||
sample_ids = [11, 5, 5, 5, 5, 5, 4, 4, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98]
|
||||
# fmt: on
|
||||
|
||||
outputs_char = tokenizer.decode(sample_ids, output_char_offsets=True)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for char
|
||||
self.assertTrue(len(outputs_char.keys()), 2)
|
||||
self.assertTrue("text" in outputs_char)
|
||||
self.assertTrue("char_offsets" in outputs_char)
|
||||
self.assertTrue(isinstance(outputs_char, Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
outputs_word = tokenizer.decode(sample_ids, output_word_offsets=True)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for word
|
||||
self.assertTrue(len(outputs_word.keys()), 2)
|
||||
self.assertTrue("text" in outputs_word)
|
||||
self.assertTrue("word_offsets" in outputs_word)
|
||||
self.assertTrue(isinstance(outputs_word, Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
outputs = tokenizer.decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
# check Wav2Vec2CTCTokenizerOutput keys for both
|
||||
self.assertTrue(len(outputs.keys()), 3)
|
||||
self.assertTrue("text" in outputs)
|
||||
self.assertTrue("char_offsets" in outputs)
|
||||
self.assertTrue("word_offsets" in outputs)
|
||||
self.assertTrue(isinstance(outputs, Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
# check that order of chars is correct and identical for both outputs
|
||||
self.assertEqual("".join(self.get_from_offsets(outputs["char_offsets"], "char")), outputs.text)
|
||||
self.assertEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "char"), ["H", "E", " ", "L", "L", "O", "<unk>"]
|
||||
)
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["char_offsets"], "char"),
|
||||
self.get_from_offsets(outputs_char["char_offsets"], "char"),
|
||||
)
|
||||
|
||||
# check that order of words is correct and identical to both outputs
|
||||
self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
|
||||
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["HE", "LLO<unk>"])
|
||||
self.assertListEqual(
|
||||
self.get_from_offsets(outputs["word_offsets"], "word"),
|
||||
self.get_from_offsets(outputs_word["word_offsets"], "word"),
|
||||
)
|
||||
|
||||
# check that offsets are actually correct for char
|
||||
# 0 is H, 1 is E, 6 is | (" "), 8 is 1st L, 12 is 2nd L, 13 is O, 14 is <unk>
|
||||
self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "start_offset"), [0, 1, 6, 8, 12, 13, 14])
|
||||
# 1 is H, 6 is E, 8 is | (" "), 11 is 1st L (note due to <pad>
|
||||
# different begin of 2nd L), 13 is 2nd L, 14 is O, 15 is <unk>
|
||||
self.assertListEqual(self.get_from_offsets(outputs["char_offsets"], "end_offset"), [1, 6, 8, 11, 13, 14, 15])
|
||||
|
||||
# check that offsets are actually correct for word
|
||||
# H is at 1st position of first word, first L is at 8th position of second word
|
||||
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 8])
|
||||
# last E is at 6th position of first word, first L is at last (15th) position of second word
|
||||
self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [6, 15])
|
||||
|
||||
def test_offsets_batch(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
def check_list_tuples_equal(outputs_batch, outputs_list):
|
||||
self.assertTrue(isinstance(outputs_batch, Wav2Vec2CTCTokenizerOutput))
|
||||
self.assertTrue(isinstance(outputs_list[0], Wav2Vec2CTCTokenizerOutput))
|
||||
|
||||
# transform list to ModelOutput
|
||||
outputs_batch_2 = Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in outputs_list] for k in outputs_list[0]})
|
||||
|
||||
self.assertListEqual(outputs_batch["text"], outputs_batch_2["text"])
|
||||
|
||||
def recursive_check(list_or_dict_1, list_or_dict_2):
|
||||
if isinstance(list_or_dict_1, list):
|
||||
[recursive_check(l1, l2) for l1, l2 in zip(list_or_dict_1, list_or_dict_2)]
|
||||
self.assertEqual(list_or_dict_1, list_or_dict_2)
|
||||
|
||||
if "char_offsets" in outputs_batch:
|
||||
recursive_check(outputs_batch["char_offsets"], outputs_batch_2["char_offsets"])
|
||||
|
||||
if "word_offsets" in outputs_batch:
|
||||
recursive_check(outputs_batch["word_offsets"], outputs_batch_2["word_offsets"])
|
||||
|
||||
# fmt: off
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 4, 8, 98, 32, 32, 32, 32, 4, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, tokenizer.word_delimiter_token_id, 24, 22, 22, 22, 4, 5, 77, tokenizer.pad_token_id, 22, 22, 4, 34, 34, 34, 34],
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# We assume that `decode` works as expected. All we will check now is
|
||||
# the output type is correct and the output is identical to `decode`
|
||||
|
||||
# char
|
||||
outputs_char_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True)
|
||||
outputs_char = [tokenizer.decode(ids, output_char_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_char_batch, outputs_char)
|
||||
|
||||
# word
|
||||
outputs_word_batch = tokenizer.batch_decode(sample_ids, output_word_offsets=True)
|
||||
outputs_word = [tokenizer.decode(ids, output_word_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_word_batch, outputs_word)
|
||||
|
||||
# both
|
||||
outputs_batch = tokenizer.batch_decode(sample_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
outputs = [tokenizer.decode(ids, output_word_offsets=True, output_char_offsets=True) for ids in sample_ids]
|
||||
check_list_tuples_equal(outputs_batch, outputs)
|
||||
|
||||
def test_offsets_integration(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
# pred_ids correspond to the following code
|
||||
# ```
|
||||
# from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
|
||||
# from datasets import load_dataset
|
||||
# import datasets
|
||||
# import torch
|
||||
# model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
# feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
#
|
||||
# ds = load_dataset("common_voice", "en", split="train", streaming=True)
|
||||
# ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
# ds_iter = iter(ds)
|
||||
# sample = next(ds_iter)
|
||||
#
|
||||
# input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
|
||||
# logits = model(input_values).logits
|
||||
# pred_ids = torch.argmax(logits, axis=-1).cpu().tolist()
|
||||
# ```
|
||||
# fmt: off
|
||||
pred_ids = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 11, 0, 0, 0, 22, 0, 0, 4, 4, 4, 14, 0, 0, 0, 0, 0, 8, 8, 0, 5, 5, 0, 12, 0, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 10, 0, 0, 0, 15, 0, 0, 10, 0, 0, 0, 12, 0, 0, 0, 0, 0, 7, 0, 9, 0, 0, 14, 0, 0, 0, 13, 0, 7, 0, 0, 4, 4, 0, 15, 8, 8, 0, 0, 8, 0, 26, 0, 0, 4, 4, 0, 0, 15, 0, 0, 0, 0, 0, 0, 10, 0, 26, 5, 5, 0, 4, 4, 0, 0, 12, 11, 0, 0, 5, 4, 4, 4, 0, 18, 0, 0, 0, 7, 9, 9, 0, 6, 0, 12, 12, 4, 4, 0, 6, 0, 0, 8, 0, 4, 4, 4, 0, 19, 0, 0, 8, 9, 9, 0, 0, 0, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 16, 16, 0, 0, 17, 5, 5, 5, 0, 4, 4, 4, 0, 0, 29, 29, 0, 0, 0, 0, 8, 11, 0, 9, 9, 0, 0, 0, 4, 4, 0, 12, 12, 0, 0, 0, 9, 0, 0, 0, 0, 0, 8, 18, 0, 0, 0, 4, 4, 0, 0, 8, 9, 0, 4, 4, 0, 6, 11, 5, 0, 4, 4, 0, 13, 13, 0, 0, 0, 10, 0, 0, 25, 0, 0, 6, 0, 4, 4, 0, 0, 0, 0, 7, 0, 0, 23, 0, 0, 4, 4, 0, 0, 0, 6, 11, 0, 5, 4, 4, 18, 0, 0, 0, 0, 0, 0, 7, 15, 0, 0, 0, 15, 15, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
|
||||
# wav2vec2-base downsamples input audio by a factor of 320
|
||||
# sampling rate for wav2vec2-base is 16_000
|
||||
time_offset_wav2vec2_base = 320 / 16_000
|
||||
|
||||
expected_char_time_stamps_text = ['W', 'H', 'Y', ' ', 'D', 'O', 'E', 'S', ' ', 'M', 'I', 'L', 'I', 'S', 'A', 'N', 'D', 'R', 'A', ' ', 'L', 'O', 'O', 'K', ' ', 'L', 'I', 'K', 'E', ' ', 'S', 'H', 'E', ' ', 'W', 'A', 'N', 'T', 'S', ' ', 'T', 'O', ' ', 'C', 'O', 'N', 'S', 'U', 'M', 'E', ' ', 'J', 'O', 'H', 'N', ' ', 'S', 'N', 'O', 'W', ' ', 'O', 'N', ' ', 'T', 'H', 'E', ' ', 'R', 'I', 'V', 'T', ' ', 'A', 'P', ' ', 'T', 'H', 'E', ' ', 'W', 'A', 'L', 'L', ' ']
|
||||
expected_char_time_stamps_start = [1.42, 1.44, 1.52, 1.58, 1.64, 1.76, 1.82, 1.88, 1.92, 2.26, 2.32, 2.4, 2.46, 2.54, 2.66, 2.7, 2.76, 2.84, 2.88, 2.94, 3.0, 3.02, 3.1, 3.14, 3.2, 3.28, 3.42, 3.46, 3.48, 3.54, 3.62, 3.64, 3.7, 3.72, 3.8, 3.88, 3.9, 3.96, 4.0, 4.04, 4.1, 4.16, 4.2, 4.28, 4.34, 4.36, 4.48, 4.66, 4.74, 4.76, 4.84, 4.94, 5.06, 5.08, 5.12, 5.22, 5.28, 5.38, 5.5, 5.52, 5.6, 5.68, 5.7, 5.74, 5.8, 5.82, 5.84, 5.88, 5.94, 6.04, 6.1, 6.16, 6.2, 6.32, 6.38, 6.44, 6.54, 6.56, 6.6, 6.62, 6.66, 6.8, 6.82, 6.9, 6.96]
|
||||
expected_char_time_stamps_end = [1.44, 1.46, 1.54, 1.64, 1.66, 1.8, 1.86, 1.9, 2.06, 2.28, 2.34, 2.42, 2.48, 2.56, 2.68, 2.72, 2.78, 2.86, 2.9, 2.98, 3.02, 3.06, 3.12, 3.16, 3.24, 3.3, 3.44, 3.48, 3.52, 3.58, 3.64, 3.66, 3.72, 3.78, 3.82, 3.9, 3.94, 3.98, 4.04, 4.08, 4.12, 4.18, 4.26, 4.3, 4.36, 4.4, 4.52, 4.7, 4.76, 4.82, 4.9, 4.98, 5.08, 5.1, 5.16, 5.26, 5.32, 5.4, 5.52, 5.54, 5.64, 5.7, 5.72, 5.78, 5.82, 5.84, 5.86, 5.92, 5.98, 6.06, 6.12, 6.18, 6.24, 6.34, 6.4, 6.48, 6.56, 6.58, 6.62, 6.66, 6.68, 6.82, 6.84, 6.94, 7.02]
|
||||
|
||||
expected_word_time_stamps_text = ['WHY', 'DOES', 'MILISANDRA', 'LOOK', 'LIKE', 'SHE', 'WANTS', 'TO', 'CONSUME', 'JOHN', 'SNOW', 'ON', 'THE', 'RIVT', 'AP', 'THE', 'WALL']
|
||||
expected_word_time_stamps_start = [1.42, 1.64, 2.26, 3.0, 3.28, 3.62, 3.8, 4.1, 4.28, 4.94, 5.28, 5.68, 5.8, 5.94, 6.32, 6.54, 6.66]
|
||||
expected_word_time_stamps_end = [1.54, 1.9, 2.9, 3.16, 3.52, 3.72, 4.04, 4.18, 4.82, 5.16, 5.54, 5.72, 5.86, 6.18, 6.4, 6.62, 6.94]
|
||||
# fmt: on
|
||||
|
||||
output = tokenizer.batch_decode(pred_ids, output_char_offsets=True, output_word_offsets=True)
|
||||
|
||||
char_offsets_text = self.get_from_offsets(output["char_offsets"][0], "char")
|
||||
char_offsets_start = self.get_from_offsets(output["char_offsets"][0], "start_offset")
|
||||
char_offsets_end = self.get_from_offsets(output["char_offsets"][0], "end_offset")
|
||||
|
||||
word_offsets_text = self.get_from_offsets(output["word_offsets"][0], "word")
|
||||
word_offsets_start = self.get_from_offsets(output["word_offsets"][0], "start_offset")
|
||||
word_offsets_end = self.get_from_offsets(output["word_offsets"][0], "end_offset")
|
||||
|
||||
# let's transform offsets to time stamps in seconds
|
||||
char_time_stamps_start = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_start]
|
||||
char_time_stamps_end = [round(c * time_offset_wav2vec2_base, 2) for c in char_offsets_end]
|
||||
|
||||
word_time_stamps_start = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_start]
|
||||
word_time_stamps_end = [round(w * time_offset_wav2vec2_base, 2) for w in word_offsets_end]
|
||||
|
||||
# NOTE: you can verify the above results by checking out the dataset viewer
|
||||
# on https://huggingface.co/datasets/common_voice/viewer/en/train and
|
||||
# downloading / playing the sample `common_voice_en_100038.mp3`. As
|
||||
# you can hear the time-stamps match more or less
|
||||
|
||||
self.assertListEqual(expected_char_time_stamps_text, char_offsets_text)
|
||||
self.assertListEqual(expected_char_time_stamps_start, char_time_stamps_start)
|
||||
self.assertListEqual(expected_char_time_stamps_end, char_time_stamps_end)
|
||||
|
||||
self.assertListEqual(expected_word_time_stamps_text, word_offsets_text)
|
||||
self.assertListEqual(expected_word_time_stamps_start, word_time_stamps_start)
|
||||
self.assertListEqual(expected_word_time_stamps_end, word_time_stamps_end)
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no testing
|
||||
pass
|
||||
|
||||
# overwrite from test_tokenization_common
|
||||
def test_add_tokens_tokenizer(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
vocab_size = tokenizer.vocab_size
|
||||
all_size = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size, 0)
|
||||
|
||||
# We usually have added tokens from the start in tests because our vocab fixtures are
|
||||
# smaller than the original vocabs - let's not assert this
|
||||
# self.assertEqual(vocab_size, all_size)
|
||||
|
||||
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
|
||||
added_toks = tokenizer.add_tokens(new_toks)
|
||||
vocab_size_2 = tokenizer.vocab_size
|
||||
all_size_2 = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size_2, 0)
|
||||
self.assertEqual(vocab_size, vocab_size_2)
|
||||
self.assertEqual(added_toks, len(new_toks))
|
||||
self.assertEqual(all_size_2, all_size + len(new_toks))
|
||||
|
||||
tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)
|
||||
|
||||
self.assertGreaterEqual(len(tokens), 4)
|
||||
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
|
||||
|
||||
new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
|
||||
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
||||
vocab_size_3 = tokenizer.vocab_size
|
||||
all_size_3 = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size_3, 0)
|
||||
self.assertEqual(vocab_size, vocab_size_3)
|
||||
self.assertEqual(added_toks_2, len(new_toks_2))
|
||||
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||
|
||||
tokens = tokenizer.encode(
|
||||
">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
|
||||
)
|
||||
|
||||
self.assertGreaterEqual(len(tokens), 6)
|
||||
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
self.assertGreater(tokens[0], tokens[1])
|
||||
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
|
||||
self.assertGreater(tokens[-3], tokens[-4])
|
||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||
self.assertEqual(tokens[-3], tokenizer.pad_token_id)
|
||||
|
||||
@unittest.skip("The tokenizer shouldn't be used to encode input IDs (except for labels), only to decode.")
|
||||
def test_tf_encode_plus_sent_to_model(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The tokenizer shouldn't be used to encode input IDs (except for labels), only to decode.")
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user