From a59abedfb5301ede6923ef0312de2ae5fa34fc97 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 12:06:26 +0200 Subject: [PATCH] DDP update --- examples/run_squad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 4704b7d4e8..6378f443e4 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -907,7 +907,7 @@ def main(): # except ImportError: # raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") - model = torch.nn.parallel.DistributedDataParallel(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) elif n_gpu > 1: model = torch.nn.DataParallel(model)