From cdfe6164b32d52f5d2125a29f0762e18ec5ac708 Mon Sep 17 00:00:00 2001 From: Yusuf Shihata <166058059+yushi2006@users.noreply.github.com> Date: Thu, 17 Jul 2025 11:24:30 +0300 Subject: [PATCH] fix(pipelines): QA pipeline returns fewer than top_k results in batch mode (#39193) * fixing the bug * Try a simpler approach * make fixup --------- Co-authored-by: Matt --- .../pipelines/question_answering.py | 12 +++++++++- .../test_pipelines_question_answering.py | 24 ++++++++++++------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 2fe92d747e..2eee80a907 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -556,8 +556,18 @@ class QuestionAnsweringPipeline(ChunkPipeline): output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None ) + pre_topk = ( + top_k * 2 + 10 if align_to_words else top_k + ) # Some candidates may be deleted if we align to words starts, ends, scores, min_null_score = select_starts_ends( - start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len + start_, + end_, + p_mask, + attention_mask, + min_null_score, + pre_topk, + handle_impossible_answer, + max_answer_len, ) if not self.tokenizer.is_fast: diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index 2de1de20d2..d46dd489c5 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -168,10 +168,11 @@ class QAPipelineTests(unittest.TestCase): ) outputs = question_answerer( - question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." + question="Where was HuggingFace founded ?", + context="HuggingFace was founded in Paris.", ) - self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + self.assertEqual(nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"}) @require_torch def test_small_model_pt_fp16(self): @@ -182,10 +183,11 @@ class QAPipelineTests(unittest.TestCase): ) outputs = question_answerer( - question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." + question="Where was HuggingFace founded ?", + context="HuggingFace was founded in Paris.", ) - self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + self.assertEqual(nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"}) @require_torch def test_small_model_pt_bf16(self): @@ -196,10 +198,11 @@ class QAPipelineTests(unittest.TestCase): ) outputs = question_answerer( - question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." + question="Where was HuggingFace founded ?", + context="HuggingFace was founded in Paris.", ) - self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + self.assertEqual(nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"}) @require_torch def test_small_model_pt_iterator(self): @@ -211,7 +214,9 @@ class QAPipelineTests(unittest.TestCase): yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."} for outputs in pipe(data()): - self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + self.assertEqual( + nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"} + ) @require_torch def test_small_model_pt_softmax_trick(self): @@ -242,10 +247,11 @@ class QAPipelineTests(unittest.TestCase): question_answerer.postprocess = ensure_large_logits_postprocess outputs = question_answerer( - question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." + question="Where was HuggingFace founded ?", + context="HuggingFace was founded in Paris.", ) - self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"}) + self.assertEqual(nested_simplify(outputs), {"score": 0.111, "start": 0, "end": 11, "answer": "HuggingFace"}) @slow @require_torch