Kill the demon spawn
This commit is contained in:
@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
|
||||
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
|
||||
output = [to_list(output[i]) for output in outputs]
|
||||
|
||||
if len(output) >= 5:
|
||||
start_logits = output[0]
|
||||
start_top_index = output[1]
|
||||
end_logits = output[2]
|
||||
end_top_index = output[3],
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits
|
||||
)
|
||||
|
||||
else:
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits
|
||||
)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
evalTime = timeit.default_timer() - start_time
|
||||
|
||||
Reference in New Issue
Block a user