Add SpeechEncoderDecoder & Speech2Text2 (#13186)

* fix_torch_device_generate_test

* remove @

* up

* correct some bugs

* correct model

* finish speech2text extension

* up

* up

* up

* up

* Update utils/custom_init_isort.py

* up

* up

* update with tokenizer

* correct old tok

* correct old tok

* fix bug

* up

* up

* add more tests

* up

* fix docs

* up

* fix some more tests

* add better config

* correct some more things
"

* fix tests

* improve docs

* Apply suggestions from code review

* Apply suggestions from code review

* final fixes

* finalize

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* apply suggestions Lysandre and Sylvain

* apply nicos suggestions

* upload everything

* finish

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: your_github_username <your_github_email>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-09-01 13:33:31 +02:00
committed by GitHub
parent 9396b40433
commit 0b8c84e110
30 changed files with 3649 additions and 73 deletions

View File

@@ -241,11 +241,15 @@ class Speech2TextModelTester:
decoder.save_pretrained(tmpdirname)
decoder = Speech2TextDecoder.from_pretrained(tmpdirname).to(torch_device)
encoder_attention_mask = encoder._get_feature_vector_attention_mask(
encoder_last_hidden_state.shape[1], inputs_dict["attention_mask"]
)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
encoder_attention_mask=inputs_dict["attention_mask"],
encoder_attention_mask=encoder_attention_mask,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
@@ -288,6 +292,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
# not implemented currently
def test_inputs_embeds(self):
pass
@@ -352,7 +357,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
else:
seq_length = self.model_tester.seq_length
subsampled_seq_length = model._get_subsampled_output_lengths(seq_length)
subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
@@ -402,8 +407,8 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
model.to(torch_device)
model.eval()
subsampled_encoder_seq_length = model._get_subsampled_output_lengths(encoder_seq_length)
subsampled_encoder_key_length = model._get_subsampled_output_lengths(encoder_key_length)
subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))