update barrier
This commit is contained in:
@@ -50,12 +50,6 @@ else:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def barrier():
|
|
||||||
t = torch.randn((), device='cuda')
|
|
||||||
torch.distributed.all_reduce(t)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -208,11 +202,11 @@ def main():
|
|||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
|
|||||||
@@ -183,10 +183,12 @@ def main():
|
|||||||
if not os.path.exists(args.output_dir):
|
if not os.path.exists(args.output_dir):
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
|
if args.local_rank not in [-1, 0]:
|
||||||
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
# Prepare model
|
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||||
|
if args.local_rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
|
|||||||
Reference in New Issue
Block a user