Fix usage of additional kwargs in from_encoder_decoder_pretrained in encoder-decoder models (#15056)

* [EncoderDecoder] Add test for usage of extra kwargs

* [EncoderDecoder] Fix usage of extra kwargs in from pretrained

* [EncoderDecoder] apply suggested changes (passing **kwargs_encoder)

* [EncoderDecoder] create new test function and make sure it passes

Co-authored-by: jonas <jsnfly@gmx.de>
This commit is contained in:
jsnfly
2022-01-19 23:00:33 +01:00
committed by GitHub
parent 3fefee9910
commit baf1ebe9f0
4 changed files with 70 additions and 6 deletions

View File

@@ -142,6 +142,48 @@ class EncoderDecoderMixin:
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_from_pretrained_using_model_paths(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
encoder_model.save_pretrained(encoder_tmp_dirname)
decoder_model.save_pretrained(decoder_tmp_dirname)
model_kwargs = {"encoder_hidden_dropout_prob": 0.0}
# BartConfig has no hidden_dropout_prob.
if not hasattr(decoder_config, "hidden_dropout_prob"):
model_kwargs["decoder_activation_function"] = "gelu"
else:
model_kwargs["decoder_hidden_dropout_prob"] = 0.0
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, **model_kwargs
)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_from_pretrained(
self,
config,
@@ -459,6 +501,10 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
def test_encoder_decoder_model_from_pretrained_using_model_paths(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_using_model_paths(**input_ids_dict, return_dict=False)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_save_and_load(**input_ids_dict)