[PretrainedFeatureExtractor] + Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2Tokenizer (#10324)
* push to show * small improvement * small improvement * Update src/transformers/feature_extraction_utils.py * Update src/transformers/feature_extraction_utils.py * implement base * add common tests * make all tests pass for wav2vec2 * make padding work & add more tests * finalize feature extractor utils * add call method to feature extraction * finalize feature processor * finish tokenizer * finish general processor design * finish tests * typo * remove bogus file * finish docstring * add docs * finish docs * small fix * correct docs * save intermediate * load changes * apply changes * apply changes to doc * change tests * apply surajs recommend * final changes * Apply suggestions from code review * fix typo * fix import * correct docstring
This commit is contained in:
committed by
GitHub
parent
9dc7825744
commit
cb38ffcc5e
284
tests/test_feature_extraction_common.py
Normal file
284
tests/test_feature_extraction_common.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BatchFeature
|
||||
from transformers.testing_utils import require_tf, require_torch
|
||||
|
||||
|
||||
class FeatureExtractionMixin:
    """Shared test cases for sequence feature extractors.

    Concrete test classes combine this mixin with ``unittest.TestCase`` and set
    the two class attributes below. The tests read the following from
    ``feat_extract_tester``: ``prepare_feat_extract_dict()``,
    ``prepare_inputs_for_common(equal_length=..., numpify=...)`` and the
    attributes ``batch_size``, ``feature_size``, ``min_seq_length``,
    ``max_seq_length`` and ``seq_length_diff``.
    """

    # to overwrite at feature extractor specific tests
    feat_extract_tester = None
    feature_extraction_class = None

    @property
    def feat_extract_dict(self):
        """Keyword arguments used to build the feature extractor under test."""
        return self.feat_extract_tester.prepare_feat_extract_dict()

    def _assert_batch_shape(self, batch_features_input, speech_inputs):
        """Check a batched input has shape ``(batch_size, seq_len, feature_size)``.

        Inputs with fewer than three dimensions get a trailing axis added first,
        so the same comparison applies to them.
        """
        if len(batch_features_input.shape) < 3:
            batch_features_input = batch_features_input[:, :, None]

        # `==` (rather than assertEqual) keeps the comparison semantics of the
        # framework-specific shape objects returned for "pt" / "tf" tensors.
        self.assertTrue(
            batch_features_input.shape
            == (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
        )

    def test_feat_extract_common_properties(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        self.assertTrue(hasattr(feat_extract, "feature_size"))
        self.assertTrue(hasattr(feat_extract, "sampling_rate"))
        self.assertTrue(hasattr(feat_extract, "padding_value"))

    def test_feat_extract_to_json_string(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        obj = json.loads(feat_extract.to_json_string())
        for key, value in self.feat_extract_dict.items():
            self.assertEqual(obj[key], value)

    def test_feat_extract_to_json_file(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "feat_extract.json")
            feat_extract_first.to_json_file(json_file_path)
            feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)

        self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())

    def test_feat_extract_from_and_save_pretrained(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            feat_extract_first.save_pretrained(tmpdirname)
            feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)

        self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())

    def test_init_without_params(self):
        feat_extract = self.feature_extraction_class()
        self.assertIsNotNone(feat_extract)

    def test_batch_feature(self):
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))

        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")

        self._assert_batch_shape(processed_features[input_name], speech_inputs)

    @require_torch
    def test_batch_feature_pt(self):
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")

        self._assert_batch_shape(processed_features[input_name], speech_inputs)

    @require_tf
    def test_batch_feature_tf(self):
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")

        self._assert_batch_shape(processed_features[input_name], speech_inputs)

    def _check_padding(self, numpify=False):
        """Exercise ``feat_extract.pad`` with every padding strategy used here.

        Args:
            numpify: forwarded to ``prepare_inputs_for_common``; selects whether
                the raw inputs are handed over as lists or as numpy arrays.
        """

        def _inputs_have_equal_length(sequences):
            length = len(sequences[0])
            return all(len(sequence) == length for sequence in sequences[1:])

        def _inputs_are_equal(batch_1, batch_2):
            if len(batch_1) != len(batch_2):
                return False

            return all(
                np.allclose(np.asarray(slice_1), np.asarray(slice_2), atol=1e-3)
                for slice_1, slice_2 in zip(batch_1, batch_2)
            )

        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        pad_diff = self.feat_extract_tester.seq_length_diff
        pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
        pad_min_length = self.feat_extract_tester.min_seq_length
        batch_size = self.feat_extract_tester.batch_size
        feature_size = self.feat_extract_tester.feature_size

        # test padding for List[int] + numpy
        input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
        input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
        input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
            input_name
        ]
        input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]

        # max_length parameter has to be provided when setting `padding="max_length"`
        with self.assertRaises(ValueError):
            feat_extract.pad(processed_features, padding="max_length")[input_name]

        input_5 = feat_extract.pad(
            processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
        )[input_name]

        self.assertFalse(_inputs_have_equal_length(input_1))
        self.assertTrue(_inputs_have_equal_length(input_2))
        self.assertTrue(_inputs_have_equal_length(input_3))
        self.assertTrue(_inputs_are_equal(input_2, input_3))
        self.assertEqual(len(input_1[0]), pad_min_length)
        self.assertEqual(len(input_1[1]), pad_min_length + pad_diff)
        self.assertEqual(input_4.shape[:2], (batch_size, len(input_3[0])))
        self.assertEqual(input_5.shape[:2], (batch_size, pad_max_length))

        if feature_size > 1:
            self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)

        # test padding for `pad_to_multiple_of` for List[int] + numpy
        input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
        input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
        input_8 = feat_extract.pad(
            processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
        )[input_name]
        input_9 = feat_extract.pad(
            processed_features,
            padding="max_length",
            pad_to_multiple_of=10,
            max_length=pad_max_length,
            return_tensors="np",
        )[input_name]

        self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
        self.assertTrue(_inputs_are_equal(input_6, input_7))

        expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
        self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
        # This used to be `assertTrue(shape, expected)`, which treats `expected`
        # as the failure message and therefore could never fail.
        self.assertEqual(input_9.shape[:2], (batch_size, expected_mult_pad_length))

        if feature_size > 1:
            self.assertTrue(input_9.shape[2] == feature_size)

        # Check padding value is correct
        # NOTE(review): the expected sums for `input_2` use `pad_max_length`
        # although `input_2` was padded with "longest"; that looks exact only
        # when the padding value is 0 or both lengths coincide - confirm.
        padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
        self.assertTrue(
            abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
            < 1e-3
        )
        self.assertTrue(
            abs(
                np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
                - padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
            )
            < 1e-3
        )
        self.assertTrue(
            abs(
                np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
                - padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
            )
            < 1e-3
        )
        self.assertTrue(
            abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
        )
        self.assertTrue(
            abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
            < 1e-3
        )

    def test_padding_from_list(self):
        self._check_padding(numpify=False)

    def test_padding_from_array(self):
        self._check_padding(numpify=True)

    @require_torch
    def test_padding_accepts_tensors_pt(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
        input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]

        self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)

    @require_tf
    def test_padding_accepts_tensors_tf(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
        input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]

        self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)

    def test_attention_mask(self):
        feat_dict = self.feat_extract_dict
        feat_dict["return_attention_mask"] = True
        feat_extract = self.feature_extraction_class(**feat_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
        input_lengths = [len(x) for x in speech_inputs]
        input_name = feat_extract.model_input_names[0]

        processed = BatchFeature({input_name: speech_inputs})

        processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
        self.assertIn("attention_mask", processed)
        self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
        # Each mask row must sum to the unpadded length of its input.
        self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lengths)
|
||||
147
tests/test_feature_extraction_wav2vec2.py
Normal file
147
tests/test_feature_extraction_wav2vec2.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .test_feature_extraction_common import FeatureExtractionMixin
|
||||
|
||||
|
||||
# Module-level generator used by `floats_list` when the caller passes no `rng`.
global_rng = random.Random()


def floats_list(shape, scale=1.0, rng=None, name=None):
    """Create a nested list of random floats.

    Args:
        shape: indexable whose first two entries give the number of rows and
            the number of values per row.
        scale: factor every value drawn from ``rng.random()`` is multiplied by.
        rng: object exposing a ``random()`` method; falls back to the
            module-level ``global_rng`` when ``None``.
        name: unused; kept so existing callers that pass it keep working.

    Returns:
        A list of ``shape[0]`` lists, each holding ``shape[1]`` floats.
    """
    if rng is None:
        rng = global_rng

    # Values are drawn row by row, so a seeded `rng` yields the same layout as
    # an explicit nested loop would.
    return [[rng.random() * scale for _ in range(shape[1])] for _ in range(shape[0])]
|
||||
|
||||
|
||||
class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
    """Holds the configuration and input factory used by the Wav2Vec2 feature extractor tests.

    Args:
        parent: object stored on ``self.parent`` (the test case that owns this tester).
        batch_size: number of sequences per generated batch.
        min_seq_length: length of the shortest generated sequence.
        max_seq_length: exclusive upper bound for generated lengths, and the
            length used for every sequence when ``equal_length=True``.
        feature_size, padding_value, sampling_rate, return_attention_mask,
        do_normalize: stored and returned verbatim by ``prepare_feat_extract_dict``.
    """

    def __init__(
        self,
        parent,
        batch_size=7,
        min_seq_length=400,
        max_seq_length=2000,
        feature_size=1,
        padding_value=0.0,
        sampling_rate=16000,
        return_attention_mask=True,
        do_normalize=True,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        # Step between consecutive sequence lengths. A batch of one has no
        # "gaps" to divide by, so fall back to a divisor of 1 instead of
        # raising ZeroDivisionError.
        self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // ((self.batch_size - 1) or 1)
        self.feature_size = feature_size
        self.padding_value = padding_value
        self.sampling_rate = sampling_rate
        self.return_attention_mask = return_attention_mask
        self.do_normalize = do_normalize

    def prepare_feat_extract_dict(self):
        """Return the keyword arguments used to build the feature extractor."""
        return {
            "feature_size": self.feature_size,
            "padding_value": self.padding_value,
            "sampling_rate": self.sampling_rate,
            "return_attention_mask": self.return_attention_mask,
            "do_normalize": self.do_normalize,
        }

    def prepare_inputs_for_common(self, equal_length=False, numpify=False):
        """Generate a batch of random inputs.

        Args:
            equal_length: when true, every sequence has ``max_seq_length``
                values; otherwise lengths grow from ``min_seq_length`` in steps
                of ``seq_length_diff`` (stopping before ``max_seq_length``).
            numpify: when true, each sequence is converted with ``np.asarray``.

        Returns:
            A list of sequences (lists of floats, or numpy arrays if ``numpify``).
        """
        if equal_length:
            speech_inputs = floats_list((self.batch_size, self.max_seq_length))
        else:
            # Each `floats_list` call yields `length` rows of `feature_size`
            # values; they are flattened into one sequence per input.
            speech_inputs = [
                list(itertools.chain.from_iterable(floats_list((length, self.feature_size))))
                for length in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]

        if numpify:
            speech_inputs = [np.asarray(x) for x in speech_inputs]

        return speech_inputs
|
||||
|
||||
|
||||
class Wav2Vec2FeatureExtractionTest(FeatureExtractionMixin, unittest.TestCase):
    """Runs the common feature extraction tests plus Wav2Vec2-specific checks."""

    feature_extraction_class = Wav2Vec2FeatureExtractor

    def setUp(self):
        self.feat_extract_tester = Wav2Vec2FeatureExtractionTester(self)

    def test_call(self):
        # Calling the extractor on Python lists and on numpy arrays must give
        # matching `input_values`, for a single input and for a batch.
        extractor_kwargs = self.feat_extract_tester.prepare_feat_extract_dict()
        feat_extract = self.feature_extraction_class(**extractor_kwargs)

        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, length))[0] for length in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(sample) for sample in speech_inputs]

        # Test not batched input
        from_list = feat_extract(speech_inputs[0], return_tensors="np").input_values
        from_array = feat_extract(np_speech_inputs[0], return_tensors="np").input_values
        self.assertTrue(np.allclose(from_list, from_array, atol=1e-3))

        # Test batched
        from_list = feat_extract(speech_inputs, return_tensors="np").input_values
        from_array = feat_extract(np_speech_inputs, return_tensors="np").input_values
        for list_row, array_row in zip(from_list, from_array):
            self.assertTrue(np.allclose(list_row, array_row, atol=1e-3))

    def test_zero_mean_unit_variance_normalization(self):
        extractor_kwargs = self.feat_extract_tester.prepare_feat_extract_dict()
        feat_extract = self.feature_extraction_class(**extractor_kwargs)
        speech_inputs = [floats_list((1, length))[0] for length in range(800, 1400, 200)]
        input_values = feat_extract(speech_inputs, padding="longest").input_values

        def _assert_normalized(vector):
            self.assertTrue(np.abs(np.mean(vector)) < 1e-3)
            self.assertTrue(np.abs(np.var(vector) - 1) < 1e-3)

        # Only the unpadded part of the two shorter inputs is inspected; the
        # longest input is checked over its full row.
        _assert_normalized(input_values[0, :800])
        _assert_normalized(input_values[1, :1000])
        _assert_normalized(input_values[2])

    @slow
    def test_pretrained_checkpoints_are_set_correctly(self):
        # this test makes sure that models that are using
        # group norm don't have their feature extractor return the
        # attention_mask
        for checkpoint in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
            config = Wav2Vec2Config.from_pretrained(checkpoint)
            feat_extract = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

            # only "layer" feature extraction norm should make use of
            # attention_mask
            uses_layer_norm = config.feat_extract_norm == "layer"
            self.assertEqual(feat_extract.return_attention_mask, uses_layer_norm)
|
||||
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@@ -324,17 +324,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_ctc_normal(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
input_values = tokenizer(input_speech, return_tensors="pt").input_values.to(torch_device)
|
||||
input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
@@ -342,11 +341,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_ctc_normal_batched(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
|
||||
@@ -354,7 +353,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
@@ -364,11 +363,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_inference_ctc_robust_batched(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
@@ -377,7 +376,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
logits = model(input_values, attention_mask=attention_mask).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
|
||||
@@ -16,9 +16,9 @@ from typing import List, Optional
|
||||
from unittest import mock
|
||||
|
||||
from transformers import is_tf_available, is_torch_available, pipeline
|
||||
from transformers.file_utils import to_py_obj
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||
from transformers.tokenization_utils_base import to_py_obj
|
||||
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
|
||||
139
tests/test_processor_wav2vec2.py
Normal file
139
tests/test_processor_wav2vec2.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
|
||||
from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
|
||||
class Wav2Vec2ProcessorTest(unittest.TestCase):
    """Checks that `Wav2Vec2Processor` agrees with the tokenizer and feature extractor it wraps."""

    def setUp(self):
        vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
        vocab_tokens = {token: index for index, token in enumerate(vocab)}

        self.add_kwargs_tokens_map = {
            "pad_token": "<pad>",
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
        }
        feature_extractor_map = {
            "feature_size": 1,
            "padding_value": 0.0,
            "sampling_rate": 16000,
            "return_attention_mask": False,
            "do_normalize": True,
        }

        # Both the vocabulary and the feature extractor config are written to
        # one temporary directory, which the `from_pretrained` calls below read.
        self.tmpdirname = tempfile.mkdtemp()
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)

        for path, payload in (
            (self.vocab_file, vocab_tokens),
            (self.feature_extraction_file, feature_extractor_map),
        ):
            with open(path, "w", encoding="utf-8") as fp:
                fp.write(json.dumps(payload) + "\n")

    def get_tokenizer(self, **kwargs):
        # The default special tokens are applied last, so they replace any
        # special token of the same name passed by the caller.
        merged_kwargs = {**kwargs, **self.add_kwargs_tokens_map}
        return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **merged_kwargs)

    def get_feature_extractor(self, **kwargs):
        return Wav2Vec2FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()

        original = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
        original.save_pretrained(self.tmpdirname)
        processor = Wav2Vec2Processor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, Wav2Vec2FeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        saved = Wav2Vec2Processor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
        saved.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)

        processor = Wav2Vec2Processor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, Wav2Vec2CTCTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, Wav2Vec2FeatureExtractor)

    def test_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        raw_speech = floats_list((3, 1000))

        input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
        input_processor = processor(raw_speech, return_tensors="np")

        # Compare by sum: every output of the bare feature extractor must be
        # matched by the processor output of the same name.
        for key in input_feat_extract.keys():
            expected_sum = input_feat_extract[key].sum()
            self.assertAlmostEqual(expected_sum, input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "This is a test string"

        with processor.as_target_processor():
            encoded_processor = processor(input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)
|
||||
@@ -23,11 +23,17 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
|
||||
from transformers import (
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2CTCTokenizer,
|
||||
Wav2Vec2Tokenizer,
|
||||
)
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
@@ -345,3 +351,101 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
|
||||
# only "layer" feature extraction norm should make use of
|
||||
# attention_mask
|
||||
self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
|
||||
|
||||
|
||||
class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    """Common tokenizer tests plus decoding checks for `Wav2Vec2CTCTokenizer`."""

    tokenizer_class = Wav2Vec2CTCTokenizer
    test_rust_tokenizer = False

    def setUp(self):
        super().setUp()

        vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
        vocab_tokens = {token: index for index, token in enumerate(vocab)}

        self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

        # The vocabulary is written to a temporary directory that
        # `get_tokenizer` loads from.
        self.tmpdirname = tempfile.mkdtemp()
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")

    def get_tokenizer(self, **kwargs):
        # The default special tokens are applied last, so they replace any
        # special token of the same name passed by the caller.
        merged_kwargs = {**kwargs, **self.special_tokens_map}
        return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **merged_kwargs)

    def test_tokenizer_decode(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
        pad_id = tokenizer.pad_token_id
        delimiter_id = tokenizer.word_delimiter_token_id

        sample_ids = [
            [11, 5, 15, pad_id, 15, 8, 98],
            [24, 22, 5, delimiter_id, 24, 22, 5, 77],
        ]
        tokens = tokenizer.decode(sample_ids[0])
        batch_tokens = tokenizer.batch_decode(sample_ids)

        # Decoding one sequence must match the first entry of the batch decode.
        self.assertEqual(tokens, batch_tokens[0])
        self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])

    def test_tokenizer_decode_special(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
        pad_id = tokenizer.pad_token_id
        delimiter_id = tokenizer.word_delimiter_token_id

        sample_ids = [
            [11, 5, 15, pad_id, 15, 8, 98],
            [24, 22, 5, delimiter_id, 24, 22, 5, 77],
        ]
        # Same content, with ids repeated back-to-back, extra pad ids and a
        # trailing word delimiter; the decoded text is expected to be identical.
        sample_ids_2 = [
            [11, 5, 5, 5, 5, 5, 15, 15, 15, pad_id, 15, 8, 98],
            [24, 22, 5, pad_id, pad_id, pad_id, delimiter_id, 24, 22, 5, 77, delimiter_id],
        ]

        batch_tokens = tokenizer.batch_decode(sample_ids)
        batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
        self.assertEqual(batch_tokens, batch_tokens_2)
        self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])

    def test_tokenizer_decode_added_tokens(self):
        tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
        tokenizer.add_tokens(["!", "?"])
        tokenizer.add_special_tokens({"cls_token": "$$$"})
        pad_id = tokenizer.pad_token_id
        delimiter_id = tokenizer.word_delimiter_token_id

        # NOTE(review): ids 32, 33 and 34 presumably map to the tokens added
        # above ("!", "?", "$$$"), judging by the expected output - confirm.
        sample_ids = [
            [11, 5, 15, pad_id, 15, 8, 98, 32, 32, 33, delimiter_id, 32, 32, 33, 34, 34],
            [24, 22, 5, delimiter_id, 24, 22, 5, 77, pad_id, 34, 34],
        ]
        batch_tokens = tokenizer.batch_decode(sample_ids)

        self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])

    def test_pretrained_model_lists(self):
        # Intentionally empty. Original note: "Wav2Vec2Model has no max model
        # length". NOTE(review): presumably this disables a same-named test
        # inherited from TokenizerTesterMixin - confirm.
        pass
|
||||
|
||||
Reference in New Issue
Block a user