From 5216607f8a67952395aca7b81b2dd9c7c60d2f4c Mon Sep 17 00:00:00 2001
From: Eldar Kurtic
Date: Tue, 29 Mar 2022 16:38:14 +0200
Subject: [PATCH] [MNLI example] Prevent overwriting matched with mismatched metrics (#16475)

* Prevent overwriting matched with mismatched metrics

* Fix style
---
 examples/pytorch/text-classification/run_glue.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index da0bb551de..92c4d2b379 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -507,6 +507,7 @@ def main():
         if data_args.task_name == "mnli":
             tasks.append("mnli-mm")
             eval_datasets.append(raw_datasets["validation_mismatched"])
+            combined = {}
 
         for eval_dataset, task in zip(eval_datasets, tasks):
             metrics = trainer.evaluate(eval_dataset=eval_dataset)
@@ -516,8 +517,13 @@ def main():
             )
             metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 
+            if task == "mnli-mm":
+                metrics = {k + "_mm": v for k, v in metrics.items()}
+            if "mnli" in task:
+                combined.update(metrics)
+
             trainer.log_metrics("eval", metrics)
-            trainer.save_metrics("eval", metrics)
+            trainer.save_metrics("eval", combined if "mnli" in task else metrics)
 
     if training_args.do_predict:
         logger.info("*** Predict ***")