clean for release

This commit is contained in:
Rémi Louf
2019-12-06 22:01:48 +01:00
committed by Julien Chaumond
parent 2a64107e44
commit f7eba09007
8 changed files with 49 additions and 376 deletions

View File

@@ -117,7 +117,8 @@ class PreTrainedEncoderDecoder(nn.Module):
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
if not argument.startswith("encoder_")
and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
@@ -157,27 +158,14 @@ class PreTrainedEncoderDecoder(nn.Module):
return model
def save_pretrained(self, save_directory, model_type="bert"):
""" Save an EncoderDecoder model and its configuration file in a format such
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format such
that it can be loaded using :func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`.
We save the encoder's and decoder's parameters in two separate directories.
If we want the weight loader to function we need to prepend the model
type to the directories' names. As far as I know there is no simple way
to infer the type of the model (except maybe by parsing the class'
names, which is not very future-proof). For now, we ask the user to
specify the model type explicitly when saving the weights.
"""
encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type))
if not os.path.exists(encoder_path):
os.makedirs(encoder_path)
self.encoder.save_pretrained(encoder_path)
decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type))
if not os.path.exists(decoder_path):
os.makedirs(decoder_path)
self.decoder.save_pretrained(decoder_path)
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
""" The forward pass on a seq2seq depends on what we are performing:
@@ -205,7 +193,8 @@ class PreTrainedEncoderDecoder(nn.Module):
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
if not argument.startswith("encoder_")
and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
@@ -228,7 +217,9 @@ class PreTrainedEncoderDecoder(nn.Module):
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
encoder_hidden_states = encoder_outputs[
0
] # output the last layer hidden state
else:
encoder_outputs = ()