LogSumExp trick question_answering pipeline. (#17143)
* LogSumExp trick `question_answering` pipeline. * Adding a failing test.
This commit is contained in:
@@ -398,8 +398,11 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
end_ = np.where(undesired_tokens_mask, -10000.0, end_)
|
end_ = np.where(undesired_tokens_mask, -10000.0, end_)
|
||||||
|
|
||||||
# Normalize logits and spans to retrieve the answer
|
# Normalize logits and spans to retrieve the answer
|
||||||
start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
|
start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True))
|
||||||
end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))
|
start_ = start_ / start_.sum()
|
||||||
|
|
||||||
|
end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True))
|
||||||
|
end_ = end_ / end_.sum()
|
||||||
|
|
||||||
if handle_impossible_answer:
|
if handle_impossible_answer:
|
||||||
min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item())
|
min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item())
|
||||||
|
|||||||
@@ -111,12 +111,47 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
question_answerer = pipeline(
|
question_answerer = pipeline(
|
||||||
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
|
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = question_answerer(
|
outputs = question_answerer(
|
||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_softmax_trick(self):
|
||||||
|
question_answerer = pipeline(
|
||||||
|
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
|
||||||
|
)
|
||||||
|
|
||||||
|
real_postprocess = question_answerer.postprocess
|
||||||
|
|
||||||
|
# Tweak start and stop to make sure we encounter the softmax logits
|
||||||
|
# bug.
|
||||||
|
def ensure_large_logits_postprocess(
|
||||||
|
model_outputs,
|
||||||
|
top_k=1,
|
||||||
|
handle_impossible_answer=False,
|
||||||
|
max_answer_len=15,
|
||||||
|
):
|
||||||
|
for output in model_outputs:
|
||||||
|
output["start"] = output["start"] * 1e6
|
||||||
|
output["end"] = output["end"] * 1e6
|
||||||
|
return real_postprocess(
|
||||||
|
model_outputs,
|
||||||
|
top_k=top_k,
|
||||||
|
handle_impossible_answer=handle_impossible_answer,
|
||||||
|
max_answer_len=max_answer_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
question_answerer.postprocess = ensure_large_logits_postprocess
|
||||||
|
|
||||||
|
outputs = question_answerer(
|
||||||
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_long_context_cls_slow(self):
|
def test_small_model_long_context_cls_slow(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user