From 29b7b30eaae52d97d4391353c562d8008f036b87 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 22:20:21 +0200 Subject: [PATCH] updating evaluation on a single gpu --- examples/run_classifier.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index c3a16f593d..27a17d7e31 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -306,10 +306,10 @@ def main(): logger.info(" Num steps = %d", num_train_optimization_steps) model.train() - for _ in trange(int(args.num_train_epochs), desc="Epoch"): + for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]): tr_loss = 0 nb_tr_examples, nb_tr_steps = 0, 0 - for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): + for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch @@ -367,21 +367,13 @@ def main(): # Load a trained model and vocabulary that you have fine-tuned model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels) tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + else: + model = BertForQuestionAnswering.from_pretrained(args.bert_model) - # Distributed/fp16/parallel settings (optional) - model.to(device) - if args.fp16: - model.half() - if args.local_rank != -1: - model = torch.nn.parallel.DistributedDataParallel(model, - device_ids=[args.local_rank], - output_device=args.local_rank, - find_unused_parameters=True) - elif n_gpu > 1: - model = torch.nn.DataParallel(model) + model.to(device) ### Evaluation - if args.do_eval: + if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = processor.get_dev_examples(args.data_dir) cached_eval_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format( list(filter(None, args.bert_model.split('/'))).pop(),