[Wav2Vec2] Improve Tokenizer & Model for batched inference (#10117)
* save intermediate * finish batch the same as fairseq * add normalization * fix batched input * add better comment * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py * add nice docstring * add tokenizer tests * make all slow tests pass * finish PR * correct import
This commit is contained in:
committed by
GitHub
parent
2f3b5f4dcc
commit
495c157d6f
@@ -23,7 +23,10 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
|
||||
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
@@ -299,3 +302,46 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
|
||||
for parameter_name, parameter in signature.parameters.items():
|
||||
if parameter.default != inspect.Parameter.empty:
|
||||
self.assertIn(parameter_name, tokenizer.init_kwargs)
|
||||
|
||||
def test_zero_mean_unit_variance_normalization(self):
    """Check that ``do_normalize=True`` yields zero-mean, unit-variance inputs.

    A batch of three utterances of increasing length is tokenized with
    ``padding="longest"``.  For every row, the mean and variance are
    measured over the leading ``length`` samples only, so that any
    padding appended to the shorter rows does not enter the statistics.
    """
    # Single source of truth for the utterance lengths: used both to build
    # the inputs and to slice each row back to its original length.
    lengths = list(range(800, 1400, 200))

    tokenizer = self.get_tokenizer(do_normalize=True)
    speech_inputs = [floats_list((1, length))[0] for length in lengths]
    processed = tokenizer(speech_inputs, padding="longest")
    input_values = processed.input_values

    def _check_zero_mean_unit_variance(input_vector):
        # assertLess (rather than assertTrue on a comparison) reports the
        # offending value when the check fails.
        self.assertLess(np.abs(np.mean(input_vector)), 1e-3)
        self.assertLess(np.abs(np.var(input_vector) - 1), 1e-3)

    # NOTE(review): the last row is sliced to its own length as well; with
    # padding="longest" this is presumably the whole row -- confirm.
    for index, length in enumerate(lengths):
        _check_zero_mean_unit_variance(input_values[index, :length])
|
||||
|
||||
def test_return_attention_mask(self):
    """Check when the tokenizer output contains an ``attention_mask``.

    With a default tokenizer no ``attention_mask`` key is present.  With
    ``return_attention_mask=True`` and ``padding="longest"`` the mask is
    present, has the same shape as ``input_values`` and sums, per row, to
    800, 1000 and 1200.
    """
    audio_batch = [floats_list((1, num_samples))[0] for num_samples in range(800, 1400, 200)]

    # default case -> the output must not expose an attention_mask
    plain_output = self.get_tokenizer()(audio_batch)
    self.assertNotIn("attention_mask", plain_output)

    # wav2vec2-lv60 -> the attention_mask is returned
    masking_tokenizer = self.get_tokenizer(return_attention_mask=True)
    padded_output = masking_tokenizer(audio_batch, padding="longest")
    self.assertIn("attention_mask", padded_output)

    # The mask must line up element-for-element with the padded inputs ...
    mask_shape = list(padded_output.attention_mask.shape)
    values_shape = list(padded_output.input_values.shape)
    self.assertListEqual(mask_shape, values_shape)

    # ... and its per-row sum must match the expected lengths.
    per_row_totals = padded_output.attention_mask.sum(-1).tolist()
    self.assertListEqual(per_row_totals, [800, 1000, 1200])
|
||||
|
||||
@slow
def test_pretrained_checkpoints_are_set_correctly(self):
    """Check tokenizer/config agreement for every listed checkpoint.

    For each id in ``WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST`` the
    tokenizer's ``return_attention_mask`` must be true exactly when the
    config's ``feat_extract_norm`` is ``"layer"``; models that use group
    norm must not have their tokenizer return the attention_mask.
    """
    for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
        # subTest keeps iterating after a failure and names the failing
        # checkpoint in the report, instead of stopping at the first one.
        with self.subTest(model_id=model_id):
            config = Wav2Vec2Config.from_pretrained(model_id)
            tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)

            # only "layer" feature extraction norm should make use of
            # attention_mask
            self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
|
||||
|
||||
Reference in New Issue
Block a user