From b6933d76d27fd14a835b9ea095d56725c69f4796 Mon Sep 17 00:00:00 2001
From: Robert Stone
Date: Wed, 3 May 2023 12:50:41 -0700
Subject: [PATCH] Tidy Pytorch GLUE benchmark example (#23134)

Migration to Evaluate for metric is not quite complete
---
 examples/pytorch/text-classification/run_glue.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 1bb4c7bee7..dd81d535df 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -486,6 +486,8 @@ def main():
     # Get the metric function
     if data_args.task_name is not None:
         metric = evaluate.load("glue", data_args.task_name)
+    elif is_regression:
+        metric = evaluate.load("mse")
     else:
         metric = evaluate.load("accuracy")
 
@@ -494,15 +496,10 @@ def main():
     def compute_metrics(p: EvalPrediction):
         preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
         preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
-        if data_args.task_name is not None:
-            result = metric.compute(predictions=preds, references=p.label_ids)
-            if len(result) > 1:
-                result["combined_score"] = np.mean(list(result.values())).item()
-            return result
-        elif is_regression:
-            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
-        else:
-            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+        result = metric.compute(predictions=preds, references=p.label_ids)
+        if len(result) > 1:
+            result["combined_score"] = np.mean(list(result.values())).item()
+        return result
 
     # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
     # we already did the padding.