Debug run_squad_pytorch
This commit is contained in:
@@ -909,11 +909,21 @@ def main():
|
|||||||
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
|
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
|
||||||
#end_logits = [x.item() for x in end_logits]
|
#end_logits = [x.item() for x in end_logits]
|
||||||
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
|
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
|
||||||
all_results.append(
|
for idx, i in enumerate(unique_id):
|
||||||
RawResult(
|
s = start_logits[idx]
|
||||||
unique_id=unique_id,
|
e = end_logits[idx]
|
||||||
start_logits=start_logits,
|
all_results.append(
|
||||||
end_logits=end_logits))
|
RawResult(
|
||||||
|
unique_id=i,
|
||||||
|
start_logits=s,
|
||||||
|
end_logits=e
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# all_results.append(
|
||||||
|
# RawResult(
|
||||||
|
# unique_id=unique_id,
|
||||||
|
# start_logits=start_logits,
|
||||||
|
# end_logits=end_logits))
|
||||||
|
|
||||||
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
||||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
||||||
|
|||||||
Reference in New Issue
Block a user