Debug run_squad_pytorch

This commit is contained in:
VictorSanh
2018-11-02 10:44:08 -04:00
parent bb0a510330
commit 101eabff90

View File

@@ -905,8 +905,10 @@ def main():
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
start_logits = [x.item() for x in start_logits]
end_logits = [x.item() for x in end_logits]
#start_logits = [x.item() for x in start_logits]
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
#end_logits = [x.item() for x in end_logits]
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
all_results.append(
RawResult(
unique_id=unique_id,