Update barrier: use torch.distributed.barrier() instead of the hand-rolled all_reduce-based barrier

This commit is contained in:
thomwolf
2019-06-18 22:43:35 +02:00
parent 4d8c4337ae
commit f7e2ac01ea
2 changed files with 6 additions and 10 deletions

View File

@@ -50,12 +50,6 @@ else:
logger = logging.getLogger(__name__)
def barrier():
    """Block until every process in the default process group reaches this call.

    Delegates to ``torch.distributed.barrier()`` instead of emulating a
    barrier with an ``all_reduce`` on a throwaway random tensor followed by
    ``torch.cuda.synchronize()``. The built-in collective avoids allocating
    a dummy tensor and does not hard-code the ``'cuda'`` device.

    The default process group must already be initialized by the caller.
    """
    torch.distributed.barrier()
def main():
parser = argparse.ArgumentParser()
@@ -208,11 +202,11 @@ def main():
num_labels = len(label_list)
if args.local_rank not in [-1, 0]:
barrier() # Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
if args.local_rank == 0:
barrier()
torch.distributed.barrier()
if args.fp16:
model.half()