From 6d80c92c77593dc674052b5a46431902e6adfe88 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 10 May 2022 10:03:55 +0200 Subject: [PATCH] LogSumExp trick `question_answering` pipeline. (#17143) * LogSumExp trick `question_answering` pipeline. * Adding a failing test. --- .../pipelines/question_answering.py | 7 ++-- .../test_pipelines_question_answering.py | 35 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index c629f703a0..bbffa3471f 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -398,8 +398,11 @@ class QuestionAnsweringPipeline(ChunkPipeline): end_ = np.where(undesired_tokens_mask, -10000.0, end_) # Normalize logits and spans to retrieve the answer - start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True))) - end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True))) + start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True)) + start_ = start_ / start_.sum() + + end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True)) + end_ = end_ / end_.sum() if handle_impossible_answer: min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item()) diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index e37fa12776..844ed0b683 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -111,12 +111,47 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): question_answerer = pipeline( "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad" ) + outputs = question_answerer( question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." ) self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + @require_torch + def test_small_model_pt_softmax_trick(self): + question_answerer = pipeline( + "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad" + ) + + real_postprocess = question_answerer.postprocess + + # Tweak start and stop to make sure we encounter the softmax logits + # bug. + def ensure_large_logits_postprocess( + model_outputs, + top_k=1, + handle_impossible_answer=False, + max_answer_len=15, + ): + for output in model_outputs: + output["start"] = output["start"] * 1e6 + output["end"] = output["end"] * 1e6 + return real_postprocess( + model_outputs, + top_k=top_k, + handle_impossible_answer=handle_impossible_answer, + max_answer_len=max_answer_len, + ) + + question_answerer.postprocess = ensure_large_logits_postprocess + + outputs = question_answerer( + question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." + ) + + self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"}) + @slow @require_torch def test_small_model_long_context_cls_slow(self):