From e363e1d936e5bfa03e8ddcaa47348c16d42bc6ac Mon Sep 17 00:00:00 2001 From: Russell Klopfer Date: Mon, 7 Jun 2021 22:34:10 -0400 Subject: [PATCH] adds metric prefix. (#12057) * adds metric prefix. * update tests to include prefix --- examples/pytorch/question-answering/trainer_qa.py | 14 ++++++++++++-- examples/pytorch/test_examples.py | 6 +++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py index 702d8ac6ab..7f98eba236 100644 --- a/examples/pytorch/question-answering/trainer_qa.py +++ b/examples/pytorch/question-answering/trainer_qa.py @@ -31,7 +31,7 @@ class QuestionAnsweringTrainer(Trainer): self.eval_examples = eval_examples self.post_process_function = post_process_function - def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None): + def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset eval_dataloader = self.get_eval_dataloader(eval_dataset) eval_examples = self.eval_examples if eval_examples is None else eval_examples @@ -56,6 +56,11 @@ class QuestionAnsweringTrainer(Trainer): eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) metrics = self.compute_metrics(eval_preds) + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + self.log(metrics) else: metrics = {} @@ -67,7 +72,7 @@ class QuestionAnsweringTrainer(Trainer): self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) return metrics - def predict(self, predict_dataset, predict_examples, ignore_keys=None): + def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): predict_dataloader = self.get_test_dataloader(predict_dataset) # Temporarily disable metric computation, we will do it in the loop here. @@ -92,4 +97,9 @@ class QuestionAnsweringTrainer(Trainer): predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") metrics = self.compute_metrics(predictions) + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 717bca47c6..74f1cb28c1 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -213,7 +213,7 @@ class ExamplesTests(TestCasePlus): tmp_dir = self.get_auto_remove_tmp_dir() testargs = f""" - run_squad.py + run_qa.py --model_name_or_path bert-base-uncased --version_2_with_negative --train_file tests/fixtures/tests_samples/SQUAD/sample.json @@ -232,8 +232,8 @@ class ExamplesTests(TestCasePlus): with patch.object(sys, "argv", testargs): run_squad.main() result = get_results(tmp_dir) - self.assertGreaterEqual(result["f1"], 30) - self.assertGreaterEqual(result["exact"], 30) + self.assertGreaterEqual(result["eval_f1"], 30) + self.assertGreaterEqual(result["eval_exact"], 30) def test_run_swag(self): stream_handler = logging.StreamHandler(sys.stdout)