From 156d30da94cc62f745ad9577655a6633240d3c23 Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Sat, 30 Mar 2024 09:02:31 -0700 Subject: [PATCH] Add warning message for `run_qa.py` (#29867) * improve: error message for best model metric * update: raise warning instead of error --- examples/pytorch/question-answering/run_qa.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index e605a06492..e2bb8d5e68 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -627,6 +627,14 @@ def main(): references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) + if data_args.version_2_with_negative: + accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1") + else: + accepted_best_metrics = ("exact_match", "f1") + + if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics: + warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}") + metric = evaluate.load( "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir )