From baf1ebe9f033b0d8cd1ca095ba68f6789ff3d505 Mon Sep 17 00:00:00 2001 From: jsnfly <37632631+jsnfly@users.noreply.github.com> Date: Wed, 19 Jan 2022 23:00:33 +0100 Subject: [PATCH] Fix usage of additional kwargs in `from_encoder_decoder_pretrained` in encoder-decoder models (#15056) * [EncoderDecoder] Add test for usage of extra kwargs * [EncoderDecoder] Fix usage of extra kwargs in from pretrained * [EncoderDecoder] apply suggested changes (passing **kwargs_encoder) * [EncoderDecoder] create new test function and make sure it passes Co-authored-by: jonas --- .../modeling_encoder_decoder.py | 10 +++- .../modeling_speech_encoder_decoder.py | 10 +++- .../modeling_vision_encoder_decoder.py | 10 +++- tests/test_modeling_encoder_decoder.py | 46 +++++++++++++++++++ 4 files changed, 70 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index f829773ca0..3891ad7d4d 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -371,7 +371,10 @@ class EncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " @@ -393,7 +396,10 @@ class EncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index d5db569f8b..19ece57ccb 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -380,7 +380,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " @@ -402,7 +405,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 6f9b547752..eb9b51fa7d 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -339,7 +339,10 @@ class VisionEncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " @@ -361,7 +364,10 @@ class VisionEncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index a0b35de461..9c4ab74c72 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -142,6 +142,48 @@ class EncoderDecoderMixin: outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) ) + def check_encoder_decoder_model_from_pretrained_using_model_paths( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: + encoder_model.save_pretrained(encoder_tmp_dirname) + decoder_model.save_pretrained(decoder_tmp_dirname) + model_kwargs = {"encoder_hidden_dropout_prob": 0.0} + + # BartConfig has no hidden_dropout_prob. + if not hasattr(decoder_config, "hidden_dropout_prob"): + model_kwargs["decoder_activation_function"] = "gelu" + else: + model_kwargs["decoder_hidden_dropout_prob"] = 0.0 + + enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_tmp_dirname, decoder_tmp_dirname, **model_kwargs + ) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + self.assertEqual( + outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) + ) + def check_encoder_decoder_model_from_pretrained( self, config, @@ -459,6 +501,10 @@ class EncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True) + def test_encoder_decoder_model_from_pretrained_using_model_paths(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_pretrained_using_model_paths(**input_ids_dict, return_dict=False) + def test_save_and_load_from_pretrained(self): input_ids_dict = self.prepare_config_and_inputs() self.check_save_and_load(**input_ids_dict)