From aae4edb5f0d0d5d4caa3795b700340dad3ddd6ae Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 21 Sep 2020 11:37:00 +0200 Subject: [PATCH] Addressing review comment --- examples/text-classification/run_glue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 39cd6ab403..9de123c55d 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -150,7 +150,7 @@ def main(): def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: def compute_metrics_fn(p: EvalPrediction): - preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions if output_mode == "classification": preds = np.argmax(preds, axis=1) else: # regression