Change document question answering pipeline to always return an array (#19071)
Co-authored-by: Ankur Goyal <ankur@impira.com>
This commit is contained in:
@@ -383,8 +383,6 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
|
|||||||
answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
|
answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
|
||||||
|
|
||||||
answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
|
answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
|
||||||
if len(answers) == 1:
|
|
||||||
return answers[0]
|
|
||||||
return answers
|
return answers
|
||||||
|
|
||||||
def postprocess_donut(self, model_outputs, **kwargs):
|
def postprocess_donut(self, model_outputs, **kwargs):
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
image = INVOICE_URL
|
image = INVOICE_URL
|
||||||
question = "What is the invoice number?"
|
question = "What is the invoice number?"
|
||||||
outputs = dqa_pipeline(image=image, question=question, top_k=2)
|
outputs = dqa_pipeline(image=image, question=question, top_k=2)
|
||||||
self.assertEqual(nested_simplify(outputs, decimals=4), {"answer": "us-001"})
|
self.assertEqual(nested_simplify(outputs, decimals=4), [{"answer": "us-001"}])
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@unittest.skip("Document question answering not implemented in TF")
|
@unittest.skip("Document question answering not implemented in TF")
|
||||||
|
|||||||
Reference in New Issue
Block a user