Adding TFWav2Vec2Model (#11617)
* [WIP] Add TFWav2Vec2Model Work in progress for adding a tensorflow version of Wav2Vec2 * feedback changes * small fix * Test Feedback Round 1 * Add SpecAugment and CTC Loss * correct spec augment mask creation * docstring and correct copyright * correct bugs * remove bogus file * finish tests correction * del unnecessary layers * Update src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * make style * correct final bug * Feedback Changes Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -445,6 +445,8 @@ class TFModelTesterMixin:
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||
if type(key) == bool:
|
||||
pt_inputs_dict[name] = key
|
||||
elif name == "input_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
@@ -455,6 +457,7 @@ class TFModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs_dict)
|
||||
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
|
||||
tf_hidden_states = tfo[0].numpy()
|
||||
pt_hidden_states = pto[0].numpy()
|
||||
|
||||
@@ -486,6 +489,8 @@ class TFModelTesterMixin:
|
||||
if type(key) == bool:
|
||||
key = np.array(key, dtype=bool)
|
||||
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
|
||||
elif name == "input_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
@@ -1061,7 +1066,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
input_ids = inputs_dict.get("input_ids", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -1097,7 +1102,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
input_ids = inputs_dict.get("input_ids", None)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user