From 3e847449adaeaab6963462c32539f5f17daecf6f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 16:53:31 +0200 Subject: [PATCH] fix out_label_ids --- examples/run_classifier.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index d945b0dfc8..b1749ec8be 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -420,6 +420,7 @@ def main(): eval_loss = 0 nb_eval_steps = 0 preds = [] + out_label_ids = [] for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): input_ids = input_ids.to(device) @@ -442,9 +443,12 @@ def main(): nb_eval_steps += 1 if len(preds) == 0: preds.append(logits.detach().cpu().numpy()) + out_label_ids.append(label_ids.detach().cpu().numpy()) else: preds[0] = np.append( preds[0], logits.detach().cpu().numpy(), axis=0) + out_label_ids[0] = np.append( + out_label_ids[0], label_ids.detach().cpu().numpy(), axis=0) eval_loss = eval_loss / nb_eval_steps preds = preds[0] @@ -452,7 +456,7 @@ def main(): preds = np.argmax(preds, axis=1) elif output_mode == "regression": preds = np.squeeze(preds) - result = compute_metrics(task_name, preds, all_label_ids.numpy()) + result = compute_metrics(task_name, preds, out_label_ids.numpy()) if args.local_rank != -1: # Average over distributed nodes if needed @@ -501,6 +505,7 @@ def main(): eval_loss = 0 nb_eval_steps = 0 preds = [] + out_label_ids = [] for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): input_ids = input_ids.to(device) @@ -518,14 +523,18 @@ def main(): nb_eval_steps += 1 if len(preds) == 0: preds.append(logits.detach().cpu().numpy()) + out_label_ids.append(label_ids.detach().cpu().numpy()) else: preds[0] = np.append( preds[0], logits.detach().cpu().numpy(), axis=0) + out_label_ids[0] = np.append( + out_label_ids[0], label_ids.detach().cpu().numpy(), axis=0) + eval_loss = eval_loss / nb_eval_steps preds = preds[0] preds = np.argmax(preds, axis=1) - result = compute_metrics(task_name, preds, all_label_ids.numpy()) + result = compute_metrics(task_name, preds, out_label_ids.numpy()) if args.local_rank != -1: # Average over distributed nodes if needed