add distributed training

This commit is contained in:
thomwolf
2018-11-04 15:32:04 +01:00
parent 1ceac85e23
commit 965b2565a0
2 changed files with 26 additions and 10 deletions

View File

@@ -449,14 +449,15 @@ def main():
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu)
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
args.accumulate_gradients))
args.batch_size = args.batch_size / args.accumulate_gradients
args.train_batch_size = args.train_batch_size / args.accumulate_gradients
random.seed(args.seed)
np.random.seed(args.seed)
@@ -502,7 +503,10 @@ def main():
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)
if n_gpu > 1:
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
no_decay = ['bias', 'gamma', 'beta']