From 123da5a2fa6a5b2d2e4e48b82c86fc13d547314e Mon Sep 17 00:00:00 2001 From: yzy5630 Date: Wed, 17 Jul 2019 09:56:07 +0800 Subject: [PATCH] fix errors for lm_finetuning examples --- .../lm_finetuning/finetune_on_pregenerated.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py index fe958345d1..fd38025424 100644 --- a/examples/lm_finetuning/finetune_on_pregenerated.py +++ b/examples/lm_finetuning/finetune_on_pregenerated.py @@ -325,15 +325,16 @@ def main(): global_step += 1 # Save a trained model - logging.info("** ** * Saving fine-tuned model ** ** * ") - model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + if torch.distributed.get_rank() == 0: + logging.info("** ** * Saving fine-tuned model ** ** * ") + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) - output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - torch.save(model_to_save.state_dict(), output_model_file) - model_to_save.config.to_json_file(output_config_file) - tokenizer.save_vocabulary(args.output_dir) + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(args.output_dir) if __name__ == '__main__':