From 2ef5e0de871e108cbd8d52c26fb47efeca6ee087 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 12:03:13 +0200 Subject: [PATCH] switch to pytorch DistributedDataParallel --- examples/run_squad.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index d4bfc02556..4704b7d4e8 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -902,12 +902,12 @@ def main(): model.half() model.to(device) if args.local_rank != -1: - try: - from apex.parallel import DistributedDataParallel as DDP - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + # try: + # from apex.parallel import DistributedDataParallel as DDP + # except ImportError: + # raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") - model = DDP(model) + model = torch.nn.parallel.DistributedDataParallel(model) elif n_gpu > 1: model = torch.nn.DataParallel(model)