Merge pull request #1804 from ronakice/master

fix multi-gpu eval in torch examples
This commit is contained in:
Thomas Wolf
2019-11-14 22:24:05 +01:00
committed by GitHub
6 changed files with 24 additions and 0 deletions

View File

@@ -217,6 +217,10 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(dataset))