From c19727fd38ed6e36f836606705217477a2f4b5c9 Mon Sep 17 00:00:00 2001
From: Santiago Castro
Date: Fri, 17 Apr 2020 11:17:21 -0400
Subject: [PATCH] Add support for the null answer in `QuestionAnsweringPipeline` (#3441)

* Add support for the null answer in `QuestionAnsweringPipeline`

* black

* Fix min null score computation

* Fix a PR comment
---
 src/transformers/pipelines.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index b1d4fa81a4..2b8084edba 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -944,6 +944,7 @@ class QuestionAnsweringPipeline(Pipeline):
         kwargs.setdefault("max_answer_len", 15)
         kwargs.setdefault("max_seq_len", 384)
         kwargs.setdefault("max_question_len", 64)
+        kwargs.setdefault("handle_impossible_answer", False)
 
         if kwargs["topk"] < 1:
             raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
@@ -982,6 +983,7 @@ class QuestionAnsweringPipeline(Pipeline):
                         start, end = self.model(**fw_args)
                         start, end = start.cpu().numpy(), end.cpu().numpy()
 
+            min_null_score = 1000000  # large and positive
             answers = []
             for (feature, start_, end_) in zip(features, start, end):
                 # Normalize logits and spans to retrieve the answer
@@ -994,8 +996,9 @@ class QuestionAnsweringPipeline(Pipeline):
                     end_ * np.abs(np.array(feature.p_mask) - 1),
                 )
 
-                # TODO : What happens if not possible
-                # Mask CLS
+                if kwargs["handle_impossible_answer"]:
+                    min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
+
                 start_[0] = end_[0] = 0
 
                 starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
@@ -1013,6 +1016,10 @@ class QuestionAnsweringPipeline(Pipeline):
                     }
                     for s, e, score in zip(starts, ends, scores)
                 ]
+
+            if kwargs["handle_impossible_answer"]:
+                answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
+
             answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
             all_answers += answers
 