From ee29871f8dedc0452749b81c50b099500870e1ee Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 11:07:32 -0400 Subject: [PATCH] Debug run_squad_pytorch --- run_squad_pytorch.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 38094e2ee2..100909b821 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -909,11 +909,21 @@ def main(): start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits] #end_logits = [x.item() for x in end_logits] end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits] - all_results.append( - RawResult( - unique_id=unique_id, - start_logits=start_logits, - end_logits=end_logits)) + for idx, i in enumerate(unique_id): + s = start_logits[idx] + e = end_logits[idx] + all_results.append( + RawResult( + unique_id=i, + start_logits=s, + end_logits=e + ) + ) + # all_results.append( + # RawResult( + # unique_id=unique_id, + # start_logits=start_logits, + # end_logits=end_logits)) output_prediction_file = os.path.join(args.output_dir, "predictions.json") output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")