Add warning message for run_qa.py (#29867)
* improve: error message for best model metric * update: raise warning instead of error
This commit is contained in:
@@ -627,6 +627,14 @@ def main():
|
|||||||
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
|
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
|
||||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||||
|
|
||||||
|
if data_args.version_2_with_negative:
|
||||||
|
accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1")
|
||||||
|
else:
|
||||||
|
accepted_best_metrics = ("exact_match", "f1")
|
||||||
|
|
||||||
|
if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics:
|
||||||
|
warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}")
|
||||||
|
|
||||||
metric = evaluate.load(
|
metric = evaluate.load(
|
||||||
"squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
|
"squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user