From 2e31176557d381171d44b5b51b72411b4c2e0601 Mon Sep 17 00:00:00 2001 From: ronakice Date: Tue, 12 Nov 2019 05:55:11 -0500 Subject: [PATCH] fix multi-gpu eval --- examples/run_glue.py | 4 ++++ examples/run_lm_finetuning.py | 4 ++++ examples/run_multiple_choice.py | 4 ++++ examples/run_ner.py | 4 ++++ examples/run_squad.py | 4 ++++ examples/run_summarization_finetuning.py | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/examples/run_glue.py b/examples/run_glue.py index 1558a812c3..f82e589301 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -224,6 +224,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu eval + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 2044cfe9e8..d9ee2fdb2b 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -300,6 +300,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_multiple_choice.py b/examples/run_multiple_choice.py index 638bbe74f1..c9e13e198d 100644 --- a/examples/run_multiple_choice.py +++ b/examples/run_multiple_choice.py @@ -229,6 +229,10 @@ def evaluate(args, model, tokenizer, prefix="", test=False): eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_ner.py b/examples/run_ner.py index b35d8298fe..c12709e37b 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -191,6 +191,10 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix="" eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation %s *****", prefix) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_squad.py b/examples/run_squad.py index d9dc2abfde..ad4656462d 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -217,6 +217,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset) eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(dataset)) diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py index 448505c727..f5604c2669 100644 --- a/examples/run_summarization_finetuning.py +++ b/examples/run_summarization_finetuning.py @@ -275,6 +275,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size ) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) logger.info(" Batch size = %d", args.eval_batch_size)