Test barrier in distributed training
This commit is contained in:
@@ -50,6 +50,12 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def barrier():
    """Block until every process in the default process group reaches this call.

    Callers use this to let one rank do some work first (for example,
    downloading a model and vocabulary) while the other ranks wait.

    Returns immediately when ``torch.distributed`` is unavailable or no
    process group has been initialized, so single-process runs can call it
    without special-casing.
    """
    # Outside an initialized distributed run there is nobody to wait for.
    # ``is_available()`` is checked first because ``is_initialized()`` may not
    # exist on builds compiled without distributed support.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return
    # Use the built-in collective rather than all-reducing a throwaway CUDA
    # tensor: it works for every backend (not only CUDA/NCCL) and does not
    # consume values from the CUDA random number generator.
    torch.distributed.barrier()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -201,10 +207,13 @@ def main():
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
if args.local_rank not in [-1, 0]:
|
||||
barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
# Prepare model
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||
if args.local_rank == 0:
|
||||
barrier()
|
||||
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user