Kill the demon spawn
This commit is contained in:
@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
|
||||
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
|
||||
output = [to_list(output[i]) for output in outputs]
|
||||
|
||||
if len(output) >= 5:
|
||||
start_logits = output[0]
|
||||
start_top_index = output[1]
|
||||
end_logits = output[2]
|
||||
end_top_index = output[3],
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits
|
||||
)
|
||||
|
||||
else:
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(
|
||||
unique_id, start_logits, end_logits
|
||||
)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
evalTime = timeit.default_timer() - start_time
|
||||
|
||||
@@ -446,72 +446,21 @@ class SquadFeatures(object):
|
||||
self.end_position = end_position
|
||||
|
||||
|
||||
|
||||
class SquadResult(object):
|
||||
"""
|
||||
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
|
||||
|
||||
Args:
|
||||
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
|
||||
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
|
||||
following accepted formats:
|
||||
|
||||
`dict` output by a simple model:
|
||||
{
|
||||
"start_logits": int,
|
||||
"end_logits": int,
|
||||
"unique_id": string
|
||||
}
|
||||
`list` output by a simple model:
|
||||
[start_logits, end_logits, unique_id]
|
||||
|
||||
`dict` output by a complex model:
|
||||
{
|
||||
"start_top_log_probs": float,
|
||||
"start_top_index": int,
|
||||
"end_top_log_probs": float,
|
||||
"end_top_index": int,
|
||||
"cls_logits": int,
|
||||
"unique_id": string
|
||||
}
|
||||
`list` output by a complex model:
|
||||
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
|
||||
|
||||
See `run_squad.py` for an example.
|
||||
unique_id: The unique identifier corresponding to that example.
|
||||
start_logits: The logits corresponding to the start of the answer
|
||||
end_logits: The logits corresponding to the end of the answer
|
||||
"""
|
||||
def __init__(self, result):
|
||||
if isinstance(result, dict):
|
||||
if "start_logits" in result and "end_logits" in result:
|
||||
self.start_logits = result["start_logits"]
|
||||
self.end_logits = result["end_logits"]
|
||||
|
||||
elif "start_top_log_probs" in result and "start_top_index" in result:
|
||||
self.start_top_log_probs = result["start_top_log_probs"]
|
||||
self.start_top_index = result["start_top_index"]
|
||||
self.end_top_log_probs = result["end_top_log_probs"]
|
||||
self.end_top_index = result["end_top_index"]
|
||||
self.cls_logits = result["cls_logits"]
|
||||
|
||||
else:
|
||||
raise ValueError("SquadResult instantiated with wrong values.")
|
||||
|
||||
self.unique_id = result["unique_id"]
|
||||
elif isinstance(result, list):
|
||||
if len(result) == 3:
|
||||
self.start_logits = result[0]
|
||||
self.end_logits = result[1]
|
||||
|
||||
elif len(result) == 6:
|
||||
self.start_top_log_probs = result[0]
|
||||
self.start_top_index = result[1]
|
||||
self.end_top_log_probs = result[2]
|
||||
self.end_top_index = result[3]
|
||||
self.cls_logits = result[4]
|
||||
|
||||
else:
|
||||
raise ValueError("SquadResult instantiated with wrong values.")
|
||||
|
||||
self.unique_id = result[-1]
|
||||
|
||||
else:
|
||||
raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")
|
||||
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
|
||||
self.start_top_log_probs = start_logits
|
||||
self.end_top_log_probs = end_logits
|
||||
self.unique_id = unique_id
|
||||
|
||||
if start_top_index:
|
||||
self.start_top_index = start_top_index
|
||||
self.end_top_index = end_top_index
|
||||
self.cls_logits = cls_logits
|
||||
Reference in New Issue
Block a user