Added SquadResult
This commit is contained in:
@@ -425,3 +425,74 @@ class SquadFeatures(object):
|
||||
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
|
||||
|
||||
|
||||
class SquadResult(object):
|
||||
"""
|
||||
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
|
||||
|
||||
Args:
|
||||
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
|
||||
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
|
||||
following accepted formats:
|
||||
|
||||
`dict` output by a simple model:
|
||||
{
|
||||
"start_logits": int,
|
||||
"end_logits": int,
|
||||
"unique_id": string
|
||||
}
|
||||
`list` output by a simple model:
|
||||
[start_logits, end_logits, unique_id]
|
||||
|
||||
`dict` output by a complex model:
|
||||
{
|
||||
"start_top_log_probs": float,
|
||||
"start_top_index": int,
|
||||
"end_top_log_probs": float,
|
||||
"end_top_index": int,
|
||||
"cls_logits": int,
|
||||
"unique_id": string
|
||||
}
|
||||
`list` output by a complex model:
|
||||
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
|
||||
|
||||
See `run_squad.py` for an example.
|
||||
"""
|
||||
def __init__(self, result):
|
||||
if isinstance(result, dict):
|
||||
if "start_logits" in result and "end_logits" in result:
|
||||
self.start_logits = result["start_logits"]
|
||||
self.end_logits = result["end_logits"]
|
||||
|
||||
elif "start_top_log_probs" in result and "start_top_index" in result:
|
||||
self.start_top_log_probs = result["start_top_log_probs"]
|
||||
self.start_top_index = result["start_top_index"]
|
||||
self.end_top_log_probs = result["end_top_log_probs"]
|
||||
self.end_top_index = result["end_top_index"]
|
||||
self.cls_logits = result["cls_logits"]
|
||||
|
||||
else:
|
||||
raise ValueError("SquadResult instantiated with wrong values.")
|
||||
|
||||
self.unique_id = result["unique_id"]
|
||||
elif isinstance(result, list):
|
||||
if len(result) == 3:
|
||||
self.start_logits = result[0]
|
||||
self.end_logits = result[1]
|
||||
|
||||
elif len(result) == 6:
|
||||
self.start_top_log_probs = result[0]
|
||||
self.start_top_index = result[1]
|
||||
self.end_top_log_probs = result[2]
|
||||
self.end_top_index = result[3]
|
||||
self.cls_logits = result[4]
|
||||
|
||||
else:
|
||||
raise ValueError("SquadResult instantiated with wrong values.")
|
||||
|
||||
self.unique_id = result[-1]
|
||||
|
||||
else:
|
||||
raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")
|
||||
|
||||
Reference in New Issue
Block a user