From 1a8f61110e4a6b90f75b0ac40ea8cea2817da5ac Mon Sep 17 00:00:00 2001 From: Sebastian Date: Tue, 9 May 2023 15:20:10 +0200 Subject: [PATCH] fix: Update run_qa.py to work with deepset/germanquad (#23225) Call str on id to make sure any ints are converted into the expected format for squad datasets --- examples/pytorch/question-answering/run_qa.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py index d3377611bd..b0dcf6e5d8 100755 --- a/examples/pytorch/question-answering/run_qa.py +++ b/examples/pytorch/question-answering/run_qa.py @@ -590,12 +590,12 @@ def main(): # Format the result to the format the metric expects. if data_args.version_2_with_negative: formatted_predictions = [ - {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + {"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() ] else: - formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()] - references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")