From f7e2ac01ea4043cb967fe75789f8e4936324fa50 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Tue, 18 Jun 2019 22:43:35 +0200
Subject: [PATCH] update barrier

---
 examples/run_classifier.py | 10 ++--------
 examples/run_squad.py      |  6 ++++--
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/examples/run_classifier.py b/examples/run_classifier.py
index 123efb9147..e708671e42 100644
--- a/examples/run_classifier.py
+++ b/examples/run_classifier.py
@@ -50,12 +50,6 @@ else:
 logger = logging.getLogger(__name__)
 
 
-def barrier():
-    t = torch.randn((), device='cuda')
-    torch.distributed.all_reduce(t)
-    torch.cuda.synchronize()
-
-
 def main():
     parser = argparse.ArgumentParser()
 
@@ -208,11 +202,11 @@ def main():
     num_labels = len(label_list)
 
     if args.local_rank not in [-1, 0]:
-        barrier() # Make sure only the first process in distributed training will download model & vocab
+        torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
     model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
     if args.local_rank == 0:
-        barrier()
+        torch.distributed.barrier()
 
     if args.fp16:
         model.half()
diff --git a/examples/run_squad.py b/examples/run_squad.py
index f20dd9d356..0d0f52e760 100644
--- a/examples/run_squad.py
+++ b/examples/run_squad.py
@@ -183,10 +183,12 @@ def main():
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
 
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
-
-    # Prepare model
     model = BertForQuestionAnswering.from_pretrained(args.bert_model)
+    if args.local_rank == 0:
+        torch.distributed.barrier()
 
     if args.fp16:
         model.half()