Merge pull request #1804 from ronakice/master

fix multi-gpu eval in torch examples
This commit is contained in:
Thomas Wolf
2019-11-14 22:24:05 +01:00
committed by GitHub
6 changed files with 24 additions and 0 deletions

View File

@@ -224,6 +224,10 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))