LogSumExp trick question_answering pipeline. (#17143)

* LogSumExp trick `question_answering` pipeline.

* Adding a failing test.
This commit is contained in:
Nicolas Patry
2022-05-10 10:03:55 +02:00
committed by GitHub
parent d719bcd46a
commit 6d80c92c77
2 changed files with 40 additions and 2 deletions

View File

@@ -111,12 +111,47 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
question_answerer = pipeline(
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
)
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
@require_torch
def test_small_model_pt_softmax_trick(self):
    """Run the tiny QA model with its start/end outputs scaled by 1e6.

    The pipeline's ``postprocess`` is wrapped so that every model output has
    its ``"start"`` and ``"end"`` entries multiplied by 1e6 before the real
    post-processing runs; the simplified result must still equal the
    expected answer dict.
    """
    qa_pipeline = pipeline(
        "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
    )
    original_postprocess = qa_pipeline.postprocess
    logit_scale = 1e6

    def scaled_logits_postprocess(
        model_outputs,
        top_k=1,
        handle_impossible_answer=False,
        max_answer_len=15,
    ):
        # Inflate the start and end values so that the softmax logits bug
        # is encountered if it is present.
        for item in model_outputs:
            for key in ("start", "end"):
                item[key] = item[key] * logit_scale
        # Delegate to the untouched post-processing with the same options.
        return original_postprocess(
            model_outputs,
            top_k=top_k,
            handle_impossible_answer=handle_impossible_answer,
            max_answer_len=max_answer_len,
        )

    # Replace the postprocess attribute on this pipeline instance with the wrapper.
    qa_pipeline.postprocess = scaled_logits_postprocess
    result = qa_pipeline(
        question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
    )
    expected = {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"}
    self.assertEqual(nested_simplify(result), expected)
@slow
@require_torch
def test_small_model_long_context_cls_slow(self):