From fd7b6a5274222add70121002560636fca2d1587b Mon Sep 17 00:00:00 2001 From: Wissam Antoun Date: Fri, 18 Dec 2020 14:53:23 +0200 Subject: [PATCH] fixed JSON error in run_qa with fp16 (#9186) --- examples/question-answering/utils_qa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/question-answering/utils_qa.py b/examples/question-answering/utils_qa.py index c26d37f9db..aad5deccf9 100644 --- a/examples/question-answering/utils_qa.py +++ b/examples/question-answering/utils_qa.py @@ -206,7 +206,7 @@ def postprocess_qa_predictions( # Make `predictions` JSON-serializable by casting np.float back to float. all_nbest_json[example["id"]] = [ - {k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()} + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} for pred in predictions ] @@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search( # Make `predictions` JSON-serializable by casting np.float back to float. all_nbest_json[example["id"]] = [ - {k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()} + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} for pred in predictions ]