Add warning message for run_qa.py (#29867)
* improve: error message for best model metric * update: raise warning instead of error
This commit is contained in:
@@ -627,6 +627,14 @@ def main():
|
||||
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
if data_args.version_2_with_negative:
|
||||
accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1")
|
||||
else:
|
||||
accepted_best_metrics = ("exact_match", "f1")
|
||||
|
||||
if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics:
|
||||
warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}")
|
||||
|
||||
metric = evaluate.load(
|
||||
"squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user