Kill the demon spawn

This commit is contained in:
LysandreJik
2019-12-04 15:42:29 -05:00
parent bf119c0568
commit cca75e7884
2 changed files with 34 additions and 64 deletions

View File

@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
output = [to_list(output[i]) for output in outputs]
if len(output) >= 5:
start_logits = output[0]
start_top_index = output[1]
end_logits = output[2]
end_top_index = output[3],
cls_logits = output[4]
result = SquadResult(
unique_id, start_logits, end_logits,
start_top_index=start_top_index,
end_top_index=end_top_index,
cls_logits=cls_logits
)
else:
start_logits, end_logits = output
result = SquadResult(
unique_id, start_logits, end_logits
)
all_results.append(result)
evalTime = timeit.default_timer() - start_time