# Extracted from the patch "Added SquadResult" (commit 285b1241e38c) against
# transformers/data/processors/squad.py.


class SquadResult(object):
    """
    Holds the raw output of a model for one SQuAD feature so it can be evaluated.

    Two families of results are supported: "simple" results made of two values
    (start/end logits) and "complex" results made of five values (the original
    documentation cites XLNet and XLM as models producing these). Values are
    stored on the instance exactly as given; no conversion or validation of the
    individual values is performed.

    Only the attributes of the detected family are set on the instance, so a
    simple result has no ``cls_logits`` attribute and a complex result has no
    ``start_logits`` attribute.

    Args:
        result: Either a ``dict`` or a positional sequence (``list`` or ``tuple``).

            ``dict`` for a simple model::

                {
                    "start_logits": ...,
                    "end_logits": ...,
                    "unique_id": ...
                }

            sequence for a simple model (exactly 3 items)::

                [start_logits, end_logits, unique_id]

            ``dict`` for a complex model::

                {
                    "start_top_log_probs": ...,
                    "start_top_index": ...,
                    "end_top_log_probs": ...,
                    "end_top_index": ...,
                    "cls_logits": ...,
                    "unique_id": ...
                }

            sequence for a complex model (exactly 6 items)::

                [start_top_log_probs, start_top_index, end_top_log_probs,
                 end_top_index, cls_logits, unique_id]

            If a ``dict`` contains the keys of both families, it is treated as
            a simple result.

    Raises:
        ValueError: if ``result`` is neither a dict nor a list/tuple, if a
            sequence does not have 3 or 6 items, or if a dict has neither
            ("start_logits" and "end_logits") nor ("start_top_log_probs" and
            "start_top_index").
        KeyError: if a dict is recognised as one of the two families but is
            missing one of the remaining keys of that family or "unique_id".

    See `run_squad.py` for an example.
    """

    # Attribute names of each result family, in the positional order used by
    # the sequence form. ``unique_id`` is handled separately: it is always the
    # last item of a sequence and the "unique_id" key of a dict.
    _SIMPLE_FIELDS = ("start_logits", "end_logits")
    _COMPLEX_FIELDS = (
        "start_top_log_probs",
        "start_top_index",
        "end_top_log_probs",
        "end_top_index",
        "cls_logits",
    )

    def __init__(self, result):
        if isinstance(result, dict):
            # The family is detected from the first two keys only; any other
            # missing key surfaces as a KeyError in the loop below.
            if "start_logits" in result and "end_logits" in result:
                fields = self._SIMPLE_FIELDS
            elif "start_top_log_probs" in result and "start_top_index" in result:
                fields = self._COMPLEX_FIELDS
            else:
                raise ValueError("SquadResult instantiated with wrong values.")

            for name in fields:
                setattr(self, name, result[name])
            self.unique_id = result["unique_id"]

        elif isinstance(result, (list, tuple)):
            # Tuples (including namedtuples) are accepted alongside lists: the
            # positional layout is identical and nothing here mutates the input.
            # The expected length is the family's fields plus the unique_id.
            if len(result) == len(self._SIMPLE_FIELDS) + 1:
                fields = self._SIMPLE_FIELDS
            elif len(result) == len(self._COMPLEX_FIELDS) + 1:
                fields = self._COMPLEX_FIELDS
            else:
                raise ValueError("SquadResult instantiated with wrong values.")

            # zip() stops after the last field, leaving the unique_id untouched.
            for name, value in zip(fields, result):
                setattr(self, name, value)
            self.unique_id = result[-1]

        else:
            raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.")