From 101eabff90ab1470827835708181b75ef4b7077c Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 10:44:08 -0400 Subject: [PATCH] Debug run_squad_pytorch --- run_squad_pytorch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 38d8447fa4..38094e2ee2 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -905,8 +905,10 @@ def main(): start_logits, end_logits = model(input_ids, segment_ids, input_mask) unique_id = [int(eval_features[e.item()].unique_id) for e in example_index] - start_logits = [x.item() for x in start_logits] - end_logits = [x.item() for x in end_logits] + #start_logits = [x.item() for x in start_logits] + start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits] + #end_logits = [x.item() for x in end_logits] + end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits] all_results.append( RawResult( unique_id=unique_id,