From 1c71ecc880ae8f04c8462e1368dc0678fdb92fc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 31 Oct 2019 10:16:08 +0100 Subject: [PATCH] load the pretrained weights for encoder-decoder We currently save the pretrained_weights of the encoder and decoder in two separate directories `encoder` and `decoder`. However, for the `from_pretrained` function to operate with automodels we need to specify the type of model in the path to the weights. The path to the encoder/decoder weights is handled by the `PreTrainedEncoderDecoder` class in the `save_pretrained` function. Sice there is no easy way to infer the type of model that was initialized for the encoder and decoder we add a parameter `model_type` to the function. This is not an ideal solution as it is error prone, and the model type should be carried by the Model classes somehow. This is a temporary fix that should be changed before merging. --- examples/run_summarization_finetuning.py | 48 ++++++++++++++---------- transformers/modeling_encoder_decoder.py | 31 +++++++++------ 2 files changed, 49 insertions(+), 30 deletions(-) diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py index f5604c2669..9c2c7769c9 100644 --- a/examples/run_summarization_finetuning.py +++ b/examples/run_summarization_finetuning.py @@ -328,6 +328,22 @@ def evaluate(args, model, tokenizer, prefix=""): return result +def save_model_checkpoints(args, model, tokenizer): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + logger.info("Saving model checkpoint to %s", args.output_dir) + + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + model_to_save = ( + model.module if hasattr(model, "module") else model + ) # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir, model_type='bert') + tokenizer.save_pretrained(args.output_dir) + torch.save(args, os.path.join(args.output_dir, "training_arguments.bin")) + + def main(): parser = argparse.ArgumentParser() @@ -454,36 +470,30 @@ def main(): # Train the model model.to(args.device) if args.do_train: - global_step, tr_loss = train(args, model, tokenizer) + try: + global_step, tr_loss = train(args, model, tokenizer) + except KeyboardInterrupt: + response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]") + if response.lower() in ["", "y", "yes"]: + save_model_checkpoints(args, model, tokenizer) + sys.exit(0) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - logger.info("Saving model checkpoint to %s", args.output_dir) - - # Save a trained model, configuration and tokenizer using `save_pretrained()`. - # They can then be reloaded using `from_pretrained()` - model_to_save = ( - model.module if hasattr(model, "module") else model - ) # Take care of distributed/parallel training - model_to_save.save_pretrained(args.output_dir) - tokenizer.save_pretrained(args.output_dir) - torch.save(args, os.path.join(args.output_dir, "training_arguments.bin")) + save_model_checkpoints(args, model, tokenizer) # Evaluate the model results = {} if args.do_evaluate: - checkpoints = [] + checkpoints = [args.output_dir] logger.info("Evaluate the following checkpoints: %s", checkpoints) for checkpoint in checkpoints: - encoder_checkpoint = os.path.join(checkpoint, "encoder") - decoder_checkpoint = os.path.join(checkpoint, "decoder") + encoder_checkpoint = os.path.join(checkpoint, "bert_encoder") + decoder_checkpoint = os.path.join(checkpoint, "bert_decoder") model = PreTrainedEncoderDecoder.from_pretrained( encoder_checkpoint, decoder_checkpoint ) model.to(args.device) - results = "placeholder" + print("model loaded") return results diff --git a/transformers/modeling_encoder_decoder.py b/transformers/modeling_encoder_decoder.py index a884abd0a2..73322101d3 100644 --- a/transformers/modeling_encoder_decoder.py +++ b/transformers/modeling_encoder_decoder.py @@ -117,8 +117,7 @@ class PreTrainedEncoderDecoder(nn.Module): kwargs_common = { argument: value for argument, value in kwargs.items() - if not argument.startswith("encoder_") - and not argument.startswith("decoder_") + if not argument.startswith("encoder_") and not argument.startswith("decoder_") } kwargs_decoder = kwargs_common.copy() kwargs_encoder = kwargs_common.copy() @@ -158,14 +157,27 @@ class PreTrainedEncoderDecoder(nn.Module): return model - def save_pretrained(self, save_directory): - """ Save a Seq2Seq model and its configuration file in a format such + def save_pretrained(self, save_directory, model_type="bert"): + """ Save an EncoderDecoder model and its configuration file in a format such that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained` We save the encoder' and decoder's parameters in two separate directories. + + If we want the weight loader to function we need to preprend the model + type to the directories' names. As far as I know there is no simple way + to infer the type of the model (except maybe by parsing the class' + names, which is not very future-proof). For now, we ask the user to + specify the model type explicitly when saving the weights. """ - self.encoder.save_pretrained(os.path.join(save_directory, "encoder")) - self.decoder.save_pretrained(os.path.join(save_directory, "decoder")) + encoder_path = os.path.join(save_directory, "{}_encoder".format(model_type)) + if not os.path.exists(encoder_path): + os.makedirs(encoder_path) + self.encoder.save_pretrained(encoder_path) + + decoder_path = os.path.join(save_directory, "{}_decoder".format(model_type)) + if not os.path.exists(decoder_path): + os.makedirs(decoder_path) + self.decoder.save_pretrained(decoder_path) def forward(self, encoder_input_ids, decoder_input_ids, **kwargs): """ The forward pass on a seq2eq depends what we are performing: @@ -193,8 +205,7 @@ class PreTrainedEncoderDecoder(nn.Module): kwargs_common = { argument: value for argument, value in kwargs.items() - if not argument.startswith("encoder_") - and not argument.startswith("decoder_") + if not argument.startswith("encoder_") and not argument.startswith("decoder_") } kwargs_decoder = kwargs_common.copy() kwargs_encoder = kwargs_common.copy() @@ -217,9 +228,7 @@ class PreTrainedEncoderDecoder(nn.Module): encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) if encoder_hidden_states is None: encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[ - 0 - ] # output the last layer hidden state + encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state else: encoder_outputs = ()