Fix usage of additional kwargs in from_encoder_decoder_pretrained in encoder-decoder models (#15056)
* [EncoderDecoder] Add test for usage of extra kwargs * [EncoderDecoder] Fix usage of extra kwargs in from pretrained * [EncoderDecoder] apply suggested changes (passing **kwargs_encoder) * [EncoderDecoder] create new test function and make sure it passes Co-authored-by: jonas <jsnfly@gmx.de>
This commit is contained in:
@@ -371,7 +371,10 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
|
||||||
|
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
@@ -393,7 +396,10 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
|
||||||
|
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
|
|||||||
@@ -380,7 +380,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
|
||||||
|
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
@@ -402,7 +405,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
|
||||||
|
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
|
|||||||
@@ -339,7 +339,10 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_encoder:
|
if "config" not in kwargs_encoder:
|
||||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
|
||||||
|
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||||
@@ -361,7 +364,10 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "config" not in kwargs_decoder:
|
if "config" not in kwargs_decoder:
|
||||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
|
||||||
|
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||||
|
|||||||
@@ -142,6 +142,48 @@ class EncoderDecoderMixin:
|
|||||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_encoder_decoder_model_from_pretrained_using_model_paths(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
|
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||||
|
encoder_model.save_pretrained(encoder_tmp_dirname)
|
||||||
|
decoder_model.save_pretrained(decoder_tmp_dirname)
|
||||||
|
model_kwargs = {"encoder_hidden_dropout_prob": 0.0}
|
||||||
|
|
||||||
|
# BartConfig has no hidden_dropout_prob.
|
||||||
|
if not hasattr(decoder_config, "hidden_dropout_prob"):
|
||||||
|
model_kwargs["decoder_activation_function"] = "gelu"
|
||||||
|
else:
|
||||||
|
model_kwargs["decoder_hidden_dropout_prob"] = 0.0
|
||||||
|
|
||||||
|
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
encoder_tmp_dirname, decoder_tmp_dirname, **model_kwargs
|
||||||
|
)
|
||||||
|
enc_dec_model.to(torch_device)
|
||||||
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||||
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model_from_pretrained(
|
def check_encoder_decoder_model_from_pretrained(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -459,6 +501,10 @@ class EncoderDecoderMixin:
|
|||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
||||||
|
|
||||||
|
def test_encoder_decoder_model_from_pretrained_using_model_paths(self):
|
||||||
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
|
self.check_encoder_decoder_model_from_pretrained_using_model_paths(**input_ids_dict, return_dict=False)
|
||||||
|
|
||||||
def test_save_and_load_from_pretrained(self):
|
def test_save_and_load_from_pretrained(self):
|
||||||
input_ids_dict = self.prepare_config_and_inputs()
|
input_ids_dict = self.prepare_config_and_inputs()
|
||||||
self.check_save_and_load(**input_ids_dict)
|
self.check_save_and_load(**input_ids_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user