From 60a1bdcdacc4ec3d2ca6dad16e61b8ad8022285e Mon Sep 17 00:00:00 2001 From: yzy5630 Date: Wed, 17 Jul 2019 09:16:20 +0800 Subject: [PATCH] fix some errors for distributed lm_finetuning --- examples/lm_finetuning/simple_lm_finetuning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py index 3008787cd1..0783c1bcc3 100644 --- a/examples/lm_finetuning/simple_lm_finetuning.py +++ b/examples/lm_finetuning/simple_lm_finetuning.py @@ -504,7 +504,7 @@ def main(): if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) - if not os.path.exists(args.output_dir): + if not os.path.exists(args.output_dir) and torch.distributed.get_rank() == 0: os.makedirs(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) @@ -613,11 +613,11 @@ def main(): global_step += 1 # Save a trained model - logger.info("** ** * Saving fine - tuned model ** ** * ") model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - if args.do_train: + if args.do_train and torch.distributed.get_rank() == 0: + logger.info("** ** * Saving fine - tuned model ** ** * ") torch.save(model_to_save.state_dict(), output_model_file) model_to_save.config.to_json_file(output_config_file) tokenizer.save_vocabulary(args.output_dir)