From 7388c83b60e97c65e399fbb88b0da1ade9897dc0 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 16:32:49 +0200 Subject: [PATCH] update run_classifier for distributed eval --- examples/run_classifier.py | 47 +++++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 6166cd7194..49fb3954b3 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -50,6 +50,15 @@ else: logger = logging.getLogger(__name__) +def average_distributed_scalar(scalar, args): + """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """ + if args.local_rank == -1: + return scalar + scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size() + torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM) + return scalar_t.item() + + def main(): parser = argparse.ArgumentParser() @@ -158,6 +167,7 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') + args.device = device logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -337,6 +347,8 @@ def main(): tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) tb_writer.add_scalar('loss', loss.item(), global_step) + ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() + ### Example: if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): # Save a trained model, configuration and tokenizer model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self @@ -352,11 +364,21 @@ def main(): # Load a trained model and vocabulary that you have fine-tuned model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels) tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) - else: - model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) - model.to(device) - if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # Distributed/fp16/parallel settings (optional) + model.to(device) + if args.fp16: + model.half() + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + elif n_gpu > 1: + model = torch.nn.DataParallel(model) + + ### Evaluation + if args.do_eval: eval_examples = processor.get_dev_examples(args.data_dir) eval_features = convert_examples_to_features( eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) @@ -374,7 +396,10 @@ def main(): eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) # Run prediction for full data - eval_sampler = SequentialSampler(eval_data) + if args.local_rank == -1: + eval_sampler = SequentialSampler(eval_data) + else: + eval_sampler = DistributedSampler(eval_data) # Note that this sampler samples randomly eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) model.eval() @@ -398,7 +423,7 @@ def main(): elif output_mode == "regression": loss_fct = MSELoss() tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1)) - + eval_loss += tmp_eval_loss.mean().item() nb_eval_steps += 1 if len(preds) == 0: @@ -414,6 +439,11 @@ def main(): elif output_mode == "regression": preds = np.squeeze(preds) result = compute_metrics(task_name, preds, all_label_ids.numpy()) + + if args.local_rank != -1: + # Average over distributed nodes if needed + result = {key: average_distributed_scalar(value, args) for key, value in result.items()} + loss = tr_loss/global_step if args.do_train else None result['eval_loss'] = eval_loss @@ -482,6 +512,11 @@ def main(): preds = preds[0] preds = np.argmax(preds, axis=1) result = compute_metrics(task_name, preds, all_label_ids.numpy()) + + if args.local_rank != -1: + # Average over distributed nodes if needed + result = {key: average_distributed_scalar(value, args) for key, value in result.items()} + loss = tr_loss/global_step if args.do_train else None result['eval_loss'] = eval_loss