Added SquadResult

This commit is contained in:
LysandreJik
2019-12-03 15:00:49 -05:00
parent 1e9ac5a7cf
commit 285b1241e3

View File

@@ -425,3 +425,74 @@ class SquadFeatures(object):
self.start_position = start_position
self.end_position = end_position
class SquadResult(object):
"""
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
Args:
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
following accepted formats:
`dict` output by a simple model:
{
"start_logits": int,
"end_logits": int,
"unique_id": string
}
`list` output by a simple model:
[start_logits, end_logits, unique_id]
`dict` output by a complex model:
{
"start_top_log_probs": float,
"start_top_index": int,
"end_top_log_probs": float,
"end_top_index": int,
"cls_logits": int,
"unique_id": string
}
`list` output by a complex model:
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
See `run_squad.py` for an example.
"""
def __init__(self, result):
if isinstance(result, dict):
if "start_logits" in result and "end_logits" in result:
self.start_logits = result["start_logits"]
self.end_logits = result["end_logits"]
elif "start_top_log_probs" in result and "start_top_index" in result:
self.start_top_log_probs = result["start_top_log_probs"]
self.start_top_index = result["start_top_index"]
self.end_top_log_probs = result["end_top_log_probs"]
self.end_top_index = result["end_top_index"]
self.cls_logits = result["cls_logits"]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result["unique_id"]
elif isinstance(result, list):
if len(result) == 3:
self.start_logits = result[0]
self.end_logits = result[1]
elif len(result) == 6:
self.start_top_log_probs = result[0]
self.start_top_index = result[1]
self.end_top_log_probs = result[2]
self.end_top_index = result[3]
self.cls_logits = result[4]
else:
raise ValueError("SquadResult instantiated with wrong values.")
self.unique_id = result[-1]
else:
raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")