From b8ff56896ccbd27a54035a90a3bc278a44541a74 Mon Sep 17 00:00:00 2001 From: wangfei <1140554608@qq.com> Date: Fri, 16 Aug 2019 12:11:05 +0800 Subject: [PATCH] Fix multi-GPU training bug in LM fine-tuning --- examples/lm_finetuning/finetune_on_pregenerated.py | 2 +- examples/lm_finetuning/simple_lm_finetuning.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py index 9fcc5f2cb1..7c40342f18 100644 --- a/examples/lm_finetuning/finetune_on_pregenerated.py +++ b/examples/lm_finetuning/finetune_on_pregenerated.py @@ -320,7 +320,7 @@ def main(): global_step += 1 # Save a trained model - if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 : + if args.local_rank == -1 or torch.distributed.get_rank() == 0: logging.info("** ** * Saving fine-tuned model ** ** * ") model.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir) diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py index ba5f832827..25333de0ed 100644 --- a/examples/lm_finetuning/simple_lm_finetuning.py +++ b/examples/lm_finetuning/simple_lm_finetuning.py @@ -507,7 +507,7 @@ def main(): if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) - if not os.path.exists(args.output_dir) and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 ): + if not os.path.exists(args.output_dir) and (args.local_rank == -1 or torch.distributed.get_rank() == 0): os.makedirs(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) @@ -608,7 +608,7 @@ def main(): global_step += 1 # Save a trained model - if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1): + if args.do_train and (args.local_rank == -1 or 
torch.distributed.get_rank() == 0): logger.info("** ** * Saving fine - tuned model ** ** * ") model.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)