Adding TFWav2Vec2Model (#11617)

* [WIP] Add TFWav2Vec2Model

Work in progress for adding a tensorflow version of Wav2Vec2

* feedback changes

* small fix

* Test Feedback Round 1

* Add SpecAugment and CTC Loss

* correct spec augment mask creation

* docstring and correct copyright

* correct bugs

* remove bogus file

* finish tests correction

* del unnecessary layers

* Update src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

* correct final bug

* Feedback Changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Will Rice
2021-06-14 13:58:54 -04:00
committed by GitHub
parent 1ed2ebf60d
commit d438eee030
12 changed files with 2250 additions and 13 deletions

View File

@@ -445,6 +445,8 @@ class TFModelTesterMixin:
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
if type(key) == bool:
pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
@@ -455,6 +457,7 @@ class TFModelTesterMixin:
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
@@ -486,6 +489,8 @@ class TFModelTesterMixin:
if type(key) == bool:
key = np.array(key, dtype=bool)
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
@@ -1061,7 +1066,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
input_ids = inputs_dict.get("input_ids", None)
# iterate over all generative models
for model_class in self.all_generative_model_classes:
@@ -1097,7 +1102,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
input_ids = inputs_dict.get("input_ids", None)
for model_class in self.all_generative_model_classes:
model = model_class(config)