[Wav2Vec2] Improve Tokenizer & Model for batched inference (#10117)

* save intermediate

* finish batch the same as fairseq

* add normalization

* fix batched input

* add better comment

* Update src/transformers/models/wav2vec2/modeling_wav2vec2.py

* add nice docstring

* add tokenizer tests

* make all slow tests pass

* finish PR

* correct import
This commit is contained in:
Patrick von Platen
2021-02-11 15:40:54 +03:00
committed by GitHub
parent 2f3b5f4dcc
commit 495c157d6f
4 changed files with 227 additions and 27 deletions

View File

@@ -18,7 +18,7 @@
import math
import unittest
from tests.test_modeling_common import floats_tensor
from tests.test_modeling_common import floats_tensor, random_attention_mask
from transformers import is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
@@ -93,6 +93,7 @@ class Wav2Vec2ModelTester:
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = Wav2Vec2Config(
hidden_size=self.hidden_size,
@@ -115,20 +116,48 @@ class Wav2Vec2ModelTester:
vocab_size=self.vocab_size,
)
return config, input_values
return config, input_values, attention_mask
def create_and_check_model(self, config, input_values):
def create_and_check_model(self, config, input_values, attention_mask):
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_values)
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_batch_inference(self, config, input_values, *args):
# Not sure how to make this test pass at the moment. Batched input yields
# same results as official fairseq implementation, but gives different results
# depending on whether batched input is used or not
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
for i in range(input_values.shape[0]):
input_slice = input_values[i : i + 1, : input_lengths[i]]
output = model(input_slice).last_hidden_state
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config, input_values = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values}
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
return config, inputs_dict
@@ -222,6 +251,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_batched_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
@@ -288,7 +321,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def test_inference_masked_lm_normal(self):
def test_inference_ctc_normal(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -306,16 +339,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_normal_batched(self):
def test_inference_ctc_normal_batched(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(2)
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
torch_device
)
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
@@ -329,18 +362,19 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_robust_batched(self):
def test_inference_ctc_robust_batched(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
torch_device
)
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)

View File

@@ -23,7 +23,10 @@ import unittest
import numpy as np
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.testing_utils import slow
global_rng = random.Random()
@@ -299,3 +302,46 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
for parameter_name, parameter in signature.parameters.items():
if parameter.default != inspect.Parameter.empty:
self.assertIn(parameter_name, tokenizer.init_kwargs)
def test_zero_mean_unit_variance_normalization(self):
tokenizer = self.get_tokenizer(do_normalize=True)
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = tokenizer(speech_inputs, padding="longest")
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0, :800])
_check_zero_mean_unit_variance(input_values[1, :1000])
_check_zero_mean_unit_variance(input_values[2])
def test_return_attention_mask(self):
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
# default case -> no attention_mask is returned
tokenizer = self.get_tokenizer()
processed = tokenizer(speech_inputs)
self.assertNotIn("attention_mask", processed)
# wav2vec2-lv60 -> return attention_mask
tokenizer = self.get_tokenizer(return_attention_mask=True)
processed = tokenizer(speech_inputs, padding="longest")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed.input_values.shape))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])
@slow
def test_pretrained_checkpoints_are_set_correctly(self):
# this test makes sure that models that are using
# group norm don't have their tokenizer return the
# attention_mask
for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
config = Wav2Vec2Config.from_pretrained(model_id)
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)
# only "layer" feature extraction norm should make use of
# attention_mask
self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")