diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index 28faccd322..efde18e13c 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -822,7 +822,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " @@ -846,7 +848,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index a685c13463..6d0b7a2004 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -835,7 +835,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " @@ -859,7 +861,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " diff --git a/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py index e6f0a49c16..d0ab1a25d1 100644 --- a/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -160,6 +160,51 @@ class FlaxEncoderDecoderMixin: max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) + def check_encoder_decoder_model_from_encoder_decoder_pretrained( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + # assert that model attributes match those of configs + self.assertEqual(config.use_cache, encoder_model.config.use_cache) + self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache) + + with tempfile.TemporaryDirectory() as enc_tmpdir: + with tempfile.TemporaryDirectory() as dec_tmpdir: + encoder_model.save_pretrained(enc_tmpdir) + decoder_model.save_pretrained(dec_tmpdir) + # load a model from pretrained encoder and decoder checkpoints, setting one encoder and one decoder kwarg opposite to that specified in their respective configs + enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_pretrained_model_name_or_path=enc_tmpdir, + decoder_pretrained_model_name_or_path=dec_tmpdir, + encoder_use_cache=not config.use_cache, + decoder_use_cache=not decoder_config.use_cache, + ) + + # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied + self.assertNotEqual(config.use_cache, enc_dec_model.config.encoder.use_cache) + self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache) + + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_hidden_states=True, + return_dict=True, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + def check_encoder_decoder_model_output_attentions( self, config, @@ -326,6 +371,10 @@ class FlaxEncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_save_and_load(**input_ids_dict) + def test_encoder_decoder_model_from_encoder_decoder_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict) + def test_encoder_decoder_model_output_attentions(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_output_attentions(**input_ids_dict) diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 981f54aad4..4ceea974f3 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -196,6 +196,51 @@ class FlaxEncoderDecoderMixin: max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 4e-2) + def check_encoder_decoder_model_from_encoder_decoder_pretrained( + self, + config, + inputs, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + # assert that loading encoder and decoder models from configs has been correctly executed + self.assertEqual(config.add_adapter, encoder_model.config.add_adapter) + self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache) + + with tempfile.TemporaryDirectory() as enc_tmpdir: + with tempfile.TemporaryDirectory() as dec_tmpdir: + encoder_model.save_pretrained(enc_tmpdir) + decoder_model.save_pretrained(dec_tmpdir) + # load a model from pretrained encoder and decoder checkpoints, setting one encoder and one decoder kwarg opposite to that specified in their respective configs + enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_pretrained_model_name_or_path=enc_tmpdir, + decoder_pretrained_model_name_or_path=dec_tmpdir, + encoder_add_adapter=not config.add_adapter, + decoder_use_cache=not decoder_config.use_cache, + ) + + # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied + self.assertNotEqual(config.add_adapter, enc_dec_model.config.encoder.add_adapter) + self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache) + + outputs_encoder_decoder = enc_dec_model( + inputs=inputs, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_hidden_states=True, + return_dict=True, + ) + + self.assertEqual( + outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) + ) + def check_encoder_decoder_model_output_attentions( self, config, @@ -441,6 +486,10 @@ class FlaxEncoderDecoderMixin: input_ids_dict = self.prepare_config_and_inputs() self.check_save_and_load(**input_ids_dict) + def test_encoder_decoder_model_from_encoder_decoder_pretrained(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict) + def test_encoder_decoder_model_output_attentions(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_output_attentions(**input_ids_dict)