Add support for the null answer in QuestionAnsweringPipeline (#3441)
* Add support for the null answer in `QuestionAnsweringPipeline` * black * Fix min null score computation * Fix a PR comment
This commit is contained in:
@@ -944,6 +944,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
kwargs.setdefault("max_answer_len", 15)
|
||||
kwargs.setdefault("max_seq_len", 384)
|
||||
kwargs.setdefault("max_question_len", 64)
|
||||
kwargs.setdefault("handle_impossible_answer", False)
|
||||
|
||||
if kwargs["topk"] < 1:
|
||||
raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
|
||||
@@ -982,6 +983,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
start, end = self.model(**fw_args)
|
||||
start, end = start.cpu().numpy(), end.cpu().numpy()
|
||||
|
||||
min_null_score = 1000000 # large and positive
|
||||
answers = []
|
||||
for (feature, start_, end_) in zip(features, start, end):
|
||||
# Normalize logits and spans to retrieve the answer
|
||||
@@ -994,8 +996,9 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
end_ * np.abs(np.array(feature.p_mask) - 1),
|
||||
)
|
||||
|
||||
# TODO : What happens if not possible
|
||||
# Mask CLS
|
||||
if kwargs["handle_impossible_answer"]:
|
||||
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
|
||||
|
||||
start_[0] = end_[0] = 0
|
||||
|
||||
starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
|
||||
@@ -1013,6 +1016,10 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
}
|
||||
for s, e, score in zip(starts, ends, scores)
|
||||
]
|
||||
|
||||
if kwargs["handle_impossible_answer"]:
|
||||
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
|
||||
|
||||
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
|
||||
all_answers += answers
|
||||
|
||||
|
||||
Reference in New Issue
Block a user