fix multi-gpu eval

This commit is contained in:
ronakice
2019-11-12 05:55:11 -05:00
parent 8aba81a0b6
commit 2e31176557
6 changed files with 24 additions and 0 deletions

View File

@@ -275,6 +275,10 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
)
# multi-gpu evaluate
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)