Add support for the null answer in QuestionAnsweringPipeline (#3441)
* Add support for the null answer in `QuestionAnsweringPipeline` * black * Fix min null score computation * Fix a PR comment
This commit is contained in:
@@ -944,6 +944,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
kwargs.setdefault("max_answer_len", 15)
|
kwargs.setdefault("max_answer_len", 15)
|
||||||
kwargs.setdefault("max_seq_len", 384)
|
kwargs.setdefault("max_seq_len", 384)
|
||||||
kwargs.setdefault("max_question_len", 64)
|
kwargs.setdefault("max_question_len", 64)
|
||||||
|
kwargs.setdefault("handle_impossible_answer", False)
|
||||||
|
|
||||||
if kwargs["topk"] < 1:
|
if kwargs["topk"] < 1:
|
||||||
raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
|
raise ValueError("topk parameter should be >= 1 (got {})".format(kwargs["topk"]))
|
||||||
@@ -982,6 +983,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
start, end = self.model(**fw_args)
|
start, end = self.model(**fw_args)
|
||||||
start, end = start.cpu().numpy(), end.cpu().numpy()
|
start, end = start.cpu().numpy(), end.cpu().numpy()
|
||||||
|
|
||||||
|
min_null_score = 1000000 # large and positive
|
||||||
answers = []
|
answers = []
|
||||||
for (feature, start_, end_) in zip(features, start, end):
|
for (feature, start_, end_) in zip(features, start, end):
|
||||||
# Normalize logits and spans to retrieve the answer
|
# Normalize logits and spans to retrieve the answer
|
||||||
@@ -994,8 +996,9 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
end_ * np.abs(np.array(feature.p_mask) - 1),
|
end_ * np.abs(np.array(feature.p_mask) - 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO : What happens if not possible
|
if kwargs["handle_impossible_answer"]:
|
||||||
# Mask CLS
|
min_null_score = min(min_null_score, (start_[0] * end_[0]).item())
|
||||||
|
|
||||||
start_[0] = end_[0] = 0
|
start_[0] = end_[0] = 0
|
||||||
|
|
||||||
starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
|
starts, ends, scores = self.decode(start_, end_, kwargs["topk"], kwargs["max_answer_len"])
|
||||||
@@ -1013,6 +1016,10 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
}
|
}
|
||||||
for s, e, score in zip(starts, ends, scores)
|
for s, e, score in zip(starts, ends, scores)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if kwargs["handle_impossible_answer"]:
|
||||||
|
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
|
||||||
|
|
||||||
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
|
answers = sorted(answers, key=lambda x: x["score"], reverse=True)[: kwargs["topk"]]
|
||||||
all_answers += answers
|
all_answers += answers
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user