diff --git a/examples/run_glue.py b/examples/run_glue.py index dc8f66434b..f3c31b0c06 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -310,7 +310,7 @@ def evaluate(args, model, tokenizer, prefix=""): eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) # multi-gpu eval - if args.n_gpu > 1: + if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): model = torch.nn.DataParallel(model) # Eval!