From 3359955622a10b46e0360e5040eeb7e3725eecfc Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 22:23:10 +0200 Subject: [PATCH] updating run_classif --- examples/run_classifier.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 27a17d7e31..47cf43e17c 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -50,15 +50,6 @@ else: logger = logging.getLogger(__name__) -def average_distributed_scalar(scalar, args): - """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """ - if args.local_rank == -1: - return scalar - scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size() - torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM) - return scalar_t.item() - - def main(): parser = argparse.ArgumentParser() @@ -368,7 +359,7 @@ def main(): model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels) tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) else: - model = BertForQuestionAnswering.from_pretrained(args.bert_model) + model = BertForSequenceClassification.from_pretrained(args.bert_model) model.to(device) @@ -453,10 +444,6 @@ def main(): preds = np.squeeze(preds) result = compute_metrics(task_name, preds, out_label_ids) - if args.local_rank != -1: - # Average over distributed nodes if needed - result = {key: average_distributed_scalar(value, args) for key, value in result.items()} - loss = tr_loss/global_step if args.do_train else None result['eval_loss'] = eval_loss @@ -510,10 +497,10 @@ def main(): with torch.no_grad(): logits = model(input_ids, segment_ids, input_mask, labels=None) - + loss_fct = CrossEntropyLoss() tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) - + eval_loss += tmp_eval_loss.mean().item() nb_eval_steps += 1 if len(preds) == 0: @@ -530,10 +517,6 @@ def main(): preds = np.argmax(preds, axis=1) result = compute_metrics(task_name, preds, out_label_ids) - if args.local_rank != -1: - # Average over distributed nodes if needed - result = {key: average_distributed_scalar(value, args) for key, value in result.items()} - loss = tr_loss/global_step if args.do_train else None result['eval_loss'] = eval_loss