fix(pipelines): QA pipeline returns fewer than top_k results in batch mode (#39193)
* fixing the bug * Try a simpler approach * make fixup --------- Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -556,8 +556,18 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None
|
output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pre_topk = (
|
||||||
|
top_k * 2 + 10 if align_to_words else top_k
|
||||||
|
) # Some candidates may be deleted if we align to words
|
||||||
starts, ends, scores, min_null_score = select_starts_ends(
|
starts, ends, scores, min_null_score = select_starts_ends(
|
||||||
start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len
|
start_,
|
||||||
|
end_,
|
||||||
|
p_mask,
|
||||||
|
attention_mask,
|
||||||
|
min_null_score,
|
||||||
|
pre_topk,
|
||||||
|
handle_impossible_answer,
|
||||||
|
max_answer_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.tokenizer.is_fast:
|
if not self.tokenizer.is_fast:
|
||||||
|
|||||||
@@ -168,10 +168,11 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
outputs = question_answerer(
|
outputs = question_answerer(
|
||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
question="Where was HuggingFace founded ?",
|
||||||
|
context="HuggingFace was founded in Paris.",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_fp16(self):
|
def test_small_model_pt_fp16(self):
|
||||||
@@ -182,10 +183,11 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
outputs = question_answerer(
|
outputs = question_answerer(
|
||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
question="Where was HuggingFace founded ?",
|
||||||
|
context="HuggingFace was founded in Paris.",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_bf16(self):
|
def test_small_model_pt_bf16(self):
|
||||||
@@ -196,10 +198,11 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
outputs = question_answerer(
|
outputs = question_answerer(
|
||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
question="Where was HuggingFace founded ?",
|
||||||
|
context="HuggingFace was founded in Paris.",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_iterator(self):
|
def test_small_model_pt_iterator(self):
|
||||||
@@ -211,7 +214,9 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
|
yield {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."}
|
||||||
|
|
||||||
for outputs in pipe(data()):
|
for outputs in pipe(data()):
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs), {"score": 0.063, "start": 0, "end": 11, "answer": "HuggingFace"}
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_softmax_trick(self):
|
def test_small_model_pt_softmax_trick(self):
|
||||||
@@ -242,10 +247,11 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
question_answerer.postprocess = ensure_large_logits_postprocess
|
question_answerer.postprocess = ensure_large_logits_postprocess
|
||||||
|
|
||||||
outputs = question_answerer(
|
outputs = question_answerer(
|
||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
question="Where was HuggingFace founded ?",
|
||||||
|
context="HuggingFace was founded in Paris.",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.111, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user