update barrier

This commit is contained in:
thomwolf
2019-06-18 22:43:35 +02:00
parent 4d8c4337ae
commit f7e2ac01ea
2 changed files with 6 additions and 10 deletions

View File

@@ -183,10 +183,12 @@ def main():
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
if args.local_rank == 0:
torch.distributed.barrier()
if args.fp16:
model.half()