From f0aeb7a814289a64a5b22577415a0cfcde3c7870 Mon Sep 17 00:00:00 2001 From: zijunsun Date: Fri, 26 Jul 2019 15:23:29 +0800 Subject: [PATCH] =?UTF-8?q?multi-gpu=20training=20also=20should=20be=20aft?= =?UTF-8?q?er=20apex=20fp16=EF=BC=88squad=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/run_squad.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 36e03fb012..692cb4a20c 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -101,6 +101,10 @@ def train(args, train_dataset, model, tokenizer): raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Distributed training (should be after apex fp16 initialization) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], @@ -457,8 +461,6 @@ def main(): torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab model.to(args.device) - if args.n_gpu > 1: - model = torch.nn.DataParallel(model) logger.info("Training/evaluation parameters %s", args)