From cca75e788485e8a2a1c44a445c6aba0fb2dfaf56 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 4 Dec 2019 15:42:29 -0500 Subject: [PATCH] Kill the demon spawn --- examples/run_squad.py | 23 +++++++- transformers/data/processors/squad.py | 75 +++++---------------------- 2 files changed, 34 insertions(+), 64 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index a9ef5c6ba2..2f86322196 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""): eval_feature = features[example_index.item()] unique_id = int(eval_feature.unique_id) - result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id]) + output = [to_list(output[i]) for output in outputs] + + if len(output) >= 5: + start_logits = output[0] + start_top_index = output[1] + end_logits = output[2] + end_top_index = output[3], + cls_logits = output[4] + + result = SquadResult( + unique_id, start_logits, end_logits, + start_top_index=start_top_index, + end_top_index=end_top_index, + cls_logits=cls_logits + ) + + else: + start_logits, end_logits = output + result = SquadResult( + unique_id, start_logits, end_logits + ) + all_results.append(result) evalTime = timeit.default_timer() - start_time diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index 2e50ac8a8c..9306189eb4 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -446,72 +446,21 @@ class SquadFeatures(object): self.end_position = end_position - class SquadResult(object): """ Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset. Args: - result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by - XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the - following accepted formats: - - `dict` output by a simple model: - { - "start_logits": int, - "end_logits": int, - "unique_id": string - } - `list` output by a simple model: - [start_logits, end_logits, unique_id] - - `dict` output by a complex model: - { - "start_top_log_probs": float, - "start_top_index": int, - "end_top_log_probs": float, - "end_top_index": int, - "cls_logits": int, - "unique_id": string - } - `list` output by a complex model: - [start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id] - - See `run_squad.py` for an example. + unique_id: The unique identifier corresponding to that example. + start_logits: The logits corresponding to the start of the answer + end_logits: The logits corresponding to the end of the answer """ - def __init__(self, result): - if isinstance(result, dict): - if "start_logits" in result and "end_logits" in result: - self.start_logits = result["start_logits"] - self.end_logits = result["end_logits"] - - elif "start_top_log_probs" in result and "start_top_index" in result: - self.start_top_log_probs = result["start_top_log_probs"] - self.start_top_index = result["start_top_index"] - self.end_top_log_probs = result["end_top_log_probs"] - self.end_top_index = result["end_top_index"] - self.cls_logits = result["cls_logits"] - - else: - raise ValueError("SquadResult instantiated with wrong values.") - - self.unique_id = result["unique_id"] - elif isinstance(result, list): - if len(result) == 3: - self.start_logits = result[0] - self.end_logits = result[1] - - elif len(result) == 6: - self.start_top_log_probs = result[0] - self.start_top_index = result[1] - self.end_top_log_probs = result[2] - self.end_top_index = result[3] - self.cls_logits = result[4] - - else: - raise ValueError("SquadResult instantiated with wrong values.") - - self.unique_id = result[-1] - - else: - raise ValueError("SquadResult instantiated with wrong values. Should be a dictionary or a list.") + def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None): + self.start_top_log_probs = start_logits + self.end_top_log_probs = end_logits + self.unique_id = unique_id + + if start_top_index: + self.start_top_index = start_top_index + self.end_top_index = end_top_index + self.cls_logits = cls_logits \ No newline at end of file