Debug run_squad_pytorch
This commit is contained in:
@@ -905,8 +905,10 @@ def main():
|
||||
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
|
||||
|
||||
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
|
||||
start_logits = [x.item() for x in start_logits]
|
||||
end_logits = [x.item() for x in end_logits]
|
||||
#start_logits = [x.item() for x in start_logits]
|
||||
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
|
||||
#end_logits = [x.item() for x in end_logits]
|
||||
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
|
||||
all_results.append(
|
||||
RawResult(
|
||||
unique_id=unique_id,
|
||||
|
||||
Reference in New Issue
Block a user