From 6a9726ec0e3b8d3841441d911fe37a0538db4d3a Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 14 Sep 2022 16:13:20 +0200 Subject: [PATCH] Fix `DocumentQuestionAnsweringPipelineTests` (#19023) * Fix DocumentQuestionAnsweringPipelineTests Co-authored-by: ydshieh --- ...t_pipelines_document_question_answering.py | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/tests/pipelines/test_pipelines_document_question_answering.py b/tests/pipelines/test_pipelines_document_question_answering.py index 7bf8ec99fb..091f6c3c03 100644 --- a/tests/pipelines/test_pipelines_document_question_answering.py +++ b/tests/pipelines/test_pipelines_document_question_answering.py @@ -113,13 +113,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli question = "How many cats are there?" expected_output = [ - { - "score": 0.0001, - "answer": "2312/2019 DUE DATE 26102/2019 ay DESCRIPTION UNIT PRICE", - "start": 38, - "end": 45, - }, - {"score": 0.0001, "answer": "2312/2019 DUE", "start": 38, "end": 39}, + {"score": 0.0001, "answer": "oy 2312/2019", "start": 38, "end": 39}, + {"score": 0.0001, "answer": "oy 2312/2019 DUE", "start": 38, "end": 40}, ] outputs = dqa_pipeline(image=image, question=question, top_k=2) self.assertEqual(nested_simplify(outputs, decimals=4), expected_output) @@ -170,8 +165,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9944, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0009, "answer": "us-001", "start": 16, "end": 16}, ], ) @@ -179,8 +174,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9944, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0009, "answer": "us-001", "start": 16, "end": 16}, ], ) @@ -191,8 +186,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli nested_simplify(outputs, decimals=4), [ [ - {"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9944, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0009, "answer": "us-001", "start": 16, "end": 16}, ], ] * 2, @@ -219,8 +214,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23}, ], ) @@ -228,8 +223,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23}, ], ) @@ -240,8 +235,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli nested_simplify(outputs, decimals=4), [ [ - {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23}, ] ] * 2, @@ -254,8 +249,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + {"score": 0.4251, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.0819, "answer": "1110212019", "start": 23, "end": 23}, ], )