Fix Loading of Flax(Speech)EncoderDecoderModel kwargs from PreTrained Encoder-Decoder Checkpoints (#16056)
* Fix Loading of Flax(Speech)EncoderDecoderModel kwargs from PreTrained Encoder-Decoder Checkpoints
* change wording
This commit is contained in:
@@ -196,6 +196,51 @@ class FlaxEncoderDecoderMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 4e-2)
|
||||
|
||||
def check_encoder_decoder_model_from_encoder_decoder_pretrained(
    self,
    config,
    inputs,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs
):
    """Check that ``encoder_``/``decoder_`` kwargs override the saved sub-model configs.

    The encoder and decoder are built from ``config`` and ``decoder_config``,
    saved to two temporary directories, and reloaded through
    ``FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained`` with
    ``encoder_add_adapter`` and ``decoder_use_cache`` set to the negation of
    the values in the original configs. The test asserts that the combined
    model's config carries the flipped values and that a forward pass returns
    logits shaped ``decoder_input_ids.shape + (decoder_config.vocab_size,)``.

    ``encoder_hidden_states`` and ``**kwargs`` are accepted but not used here.
    """
    enc, dec = self.get_encoder_decoder_model(config, decoder_config)

    # Sanity check: the freshly built sub-models mirror the configs they came from.
    self.assertEqual(config.add_adapter, enc.config.add_adapter)
    self.assertEqual(decoder_config.use_cache, dec.config.use_cache)

    with tempfile.TemporaryDirectory() as enc_tmpdir, tempfile.TemporaryDirectory() as dec_tmpdir:
        enc.save_pretrained(enc_tmpdir)
        dec.save_pretrained(dec_tmpdir)
        # Reload from the two checkpoints, flipping one encoder kwarg and one
        # decoder kwarg relative to what their respective configs specify.
        model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder_pretrained_model_name_or_path=enc_tmpdir,
            decoder_pretrained_model_name_or_path=dec_tmpdir,
            encoder_add_adapter=not config.add_adapter,
            decoder_use_cache=not decoder_config.use_cache,
        )

    # The flipped kwargs must have been applied to the loaded model's config.
    self.assertNotEqual(config.add_adapter, model.config.encoder.add_adapter)
    self.assertNotEqual(decoder_config.use_cache, model.config.decoder.use_cache)

    outputs = model(
        inputs=inputs,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        output_hidden_states=True,
        return_dict=True,
    )

    logits_shape = outputs["logits"].shape
    expected_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
    self.assertEqual(logits_shape, expected_shape)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
@@ -441,6 +486,10 @@ class FlaxEncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_and_load(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
    """Run the ``from_encoder_decoder_pretrained`` kwarg check on the prepared config and inputs."""
    prepared = self.prepare_config_and_inputs()
    self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**prepared)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions(self):
    """Run the output-attentions check on the prepared config and inputs."""
    prepared = self.prepare_config_and_inputs()
    self.check_encoder_decoder_model_output_attentions(**prepared)
|
||||
|
||||
Reference in New Issue
Block a user