load the pretrained weights for encoder-decoder
We currently save the pretrained_weights of the encoder and decoder in two separate directories `encoder` and `decoder`. However, for the `from_pretrained` function to operate with automodels we need to specify the type of model in the path to the weights. The path to the encoder/decoder weights is handled by the `PreTrainedEncoderDecoder` class in the `save_pretrained` function. Sice there is no easy way to infer the type of model that was initialized for the encoder and decoder we add a parameter `model_type` to the function. This is not an ideal solution as it is error prone, and the model type should be carried by the Model classes somehow. This is a temporary fix that should be changed before merging.
This commit is contained in:
committed by
Julien Chaumond
parent
07f4cd73f6
commit
1c71ecc880
@@ -117,8 +117,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_")
|
||||
and not argument.startswith("decoder_")
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = kwargs_common.copy()
|
||||
kwargs_encoder = kwargs_common.copy()
|
||||
@@ -158,14 +157,27 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a Seq2Seq model and its configuration file in a format such
|
||||
def save_pretrained(self, save_directory, model_type="bert"):
|
||||
""" Save an EncoderDecoder model and its configuration file in a format such
|
||||
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
|
||||
|
||||
We save the encoder' and decoder's parameters in two separate directories.
|
||||
|
||||
If we want the weight loader to function we need to preprend the model
|
||||
type to the directories' names. As far as I know there is no simple way
|
||||
to infer the type of the model (except maybe by parsing the class'
|
||||
names, which is not very future-proof). For now, we ask the user to
|
||||
specify the model type explicitly when saving the weights.
|
||||
"""
|
||||
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
||||
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
||||
encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
|
||||
if not os.path.exists(encoder_path):
|
||||
os.makedirs(encoder_path)
|
||||
self.encoder.save_pretrained(encoder_path)
|
||||
|
||||
decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
|
||||
if not os.path.exists(decoder_path):
|
||||
os.makedirs(decoder_path)
|
||||
self.decoder.save_pretrained(decoder_path)
|
||||
|
||||
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
||||
""" The forward pass on a seq2eq depends what we are performing:
|
||||
@@ -193,8 +205,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_")
|
||||
and not argument.startswith("decoder_")
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = kwargs_common.copy()
|
||||
kwargs_encoder = kwargs_common.copy()
|
||||
@@ -217,9 +228,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[
|
||||
0
|
||||
] # output the last layer hidden state
|
||||
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user