From 43b9d93875cbf6756baf402a4720ca23d8c75015 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 21 Sep 2020 15:04:20 +0530 Subject: [PATCH] [example/glue] fix compute_metrics_fn for bart like models (#7248) * fix compute_metrics_fn * p.predictions -> preds * apply suggestions --- examples/text-classification/run_glue.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 61776b2056..39cd6ab403 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -150,10 +150,11 @@ def main(): def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: def compute_metrics_fn(p: EvalPrediction): + preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions if output_mode == "classification": - preds = np.argmax(p.predictions, axis=1) - elif output_mode == "regression": - preds = np.squeeze(p.predictions) + preds = np.argmax(preds, axis=1) + else: # regression + preds = np.squeeze(preds) return glue_compute_metrics(task_name, preds, p.label_ids) return compute_metrics_fn