Test barrier in distributed training

This commit is contained in:
thomwolf
2019-06-18 22:41:28 +02:00
parent 3359955622
commit 4d8c4337ae
3 changed files with 23 additions and 17 deletions

View File

@@ -50,6 +50,12 @@ else:
logger = logging.getLogger(__name__)
def barrier():
    """Synchronize distributed processes via an all-reduce on a dummy CUDA scalar.

    Creates a throwaway 0-dim tensor on the CUDA device, runs
    ``torch.distributed.all_reduce`` on it, and then calls
    ``torch.cuda.synchronize()``. The reduced value itself is never used.

    NOTE(review): presumably requires an initialized default process group
    and an available CUDA device -- confirm against the caller's setup.
    """
    # The tensor's contents are irrelevant; it only serves as the payload
    # for the collective call below.
    dummy = torch.randn((), device="cuda")
    torch.distributed.all_reduce(dummy)
    # Wait for outstanding CUDA work (including the reduction) to finish
    # before returning to the caller.
    torch.cuda.synchronize()
def main():
parser = argparse.ArgumentParser()
@@ -201,10 +207,13 @@ def main():
label_list = processor.get_labels()
num_labels = len(label_list)
if args.local_rank not in [-1, 0]:
barrier() # Make sure only the first process in distributed training will download model & vocab
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
# Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
if args.local_rank == 0:
barrier()
if args.fp16:
model.half()
model.to(device)