adds metric prefix. (#12057)
* adds metric prefix. * update tests to include prefix
This commit is contained in:
@@ -31,7 +31,7 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
self.eval_examples = eval_examples
|
self.eval_examples = eval_examples
|
||||||
self.post_process_function = post_process_function
|
self.post_process_function = post_process_function
|
||||||
|
|
||||||
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
|
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
|
||||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||||
eval_examples = self.eval_examples if eval_examples is None else eval_examples
|
eval_examples = self.eval_examples if eval_examples is None else eval_examples
|
||||||
@@ -56,6 +56,11 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
|
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
|
||||||
metrics = self.compute_metrics(eval_preds)
|
metrics = self.compute_metrics(eval_preds)
|
||||||
|
|
||||||
|
# Prefix all keys with metric_key_prefix + '_'
|
||||||
|
for key in list(metrics.keys()):
|
||||||
|
if not key.startswith(f"{metric_key_prefix}_"):
|
||||||
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||||
|
|
||||||
self.log(metrics)
|
self.log(metrics)
|
||||||
else:
|
else:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
@@ -67,7 +72,7 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def predict(self, predict_dataset, predict_examples, ignore_keys=None):
|
def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
|
||||||
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
||||||
|
|
||||||
# Temporarily disable metric computation, we will do it in the loop here.
|
# Temporarily disable metric computation, we will do it in the loop here.
|
||||||
@@ -92,4 +97,9 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
|
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
|
||||||
metrics = self.compute_metrics(predictions)
|
metrics = self.compute_metrics(predictions)
|
||||||
|
|
||||||
|
# Prefix all keys with metric_key_prefix + '_'
|
||||||
|
for key in list(metrics.keys()):
|
||||||
|
if not key.startswith(f"{metric_key_prefix}_"):
|
||||||
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||||
|
|
||||||
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
|
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
testargs = f"""
|
testargs = f"""
|
||||||
run_squad.py
|
run_qa.py
|
||||||
--model_name_or_path bert-base-uncased
|
--model_name_or_path bert-base-uncased
|
||||||
--version_2_with_negative
|
--version_2_with_negative
|
||||||
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
|
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
|
||||||
@@ -232,8 +232,8 @@ class ExamplesTests(TestCasePlus):
|
|||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_squad.main()
|
run_squad.main()
|
||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["f1"], 30)
|
self.assertGreaterEqual(result["eval_f1"], 30)
|
||||||
self.assertGreaterEqual(result["exact"], 30)
|
self.assertGreaterEqual(result["eval_exact"], 30)
|
||||||
|
|
||||||
def test_run_swag(self):
|
def test_run_swag(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
|||||||
Reference in New Issue
Block a user