From 729b569531f644901046bc51502e792b5984bd48 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 21 Jan 2025 00:18:07 +0800 Subject: [PATCH] fix document qa bf16 pipeline (#35456) * fix document qa bf16 pipeline Signed-off-by: jiqing-feng * add test Signed-off-by: jiqing-feng * fix test Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- .../pipelines/document_question_answering.py | 5 ++ ...t_pipelines_document_question_answering.py | 47 ++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py index c176d841e2..41cd5d5d85 100644 --- a/src/transformers/pipelines/document_question_answering.py +++ b/src/transformers/pipelines/document_question_answering.py @@ -485,6 +485,11 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): for output in model_outputs: words = output["words"] + if self.framework == "pt" and output["start_logits"].dtype in (torch.bfloat16, torch.float16): + output["start_logits"] = output["start_logits"].float() + if self.framework == "pt" and output["end_logits"].dtype in (torch.bfloat16, torch.float16): + output["end_logits"] = output["end_logits"].float() + starts, ends, scores, min_null_score = select_starts_ends( start=output["start_logits"], end=output["end_logits"], diff --git a/tests/pipelines/test_pipelines_document_question_answering.py b/tests/pipelines/test_pipelines_document_question_answering.py index 305fe0c558..85d528ce91 100644 --- a/tests/pipelines/test_pipelines_document_question_answering.py +++ b/tests/pipelines/test_pipelines_document_question_answering.py @@ -14,7 +14,12 @@ import unittest -from transformers import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, AutoTokenizer, is_vision_available +from transformers import ( + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, + AutoTokenizer, + is_torch_available, + is_vision_available, +) from transformers.pipelines import DocumentQuestionAnsweringPipeline, pipeline from transformers.pipelines.document_question_answering import apply_tesseract from transformers.testing_utils import ( @@ -24,6 +29,7 @@ from transformers.testing_utils import ( require_pytesseract, require_tf, require_torch, + require_torch_bf16, require_vision, slow, ) @@ -31,6 +37,9 @@ from transformers.testing_utils import ( from .test_pipelines_common import ANY +if is_torch_available(): + import torch + if is_vision_available(): from PIL import Image @@ -145,6 +154,42 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase): outputs = dqa_pipeline(image=image, question=question, words=words, boxes=boxes, top_k=2) self.assertEqual(outputs, []) + @require_torch + @require_torch_bf16 + @require_detectron2 + @require_pytesseract + def test_small_model_pt_bf16(self): + dqa_pipeline = pipeline( + "document-question-answering", + model="hf-internal-testing/tiny-random-layoutlmv2-for-dqa-test", + torch_dtype=torch.bfloat16, + ) + image = INVOICE_URL + question = "How many cats are there?" + + expected_output = [ + {"score": 0.0001, "answer": "oy 2312/2019", "start": 38, "end": 39}, + {"score": 0.0001, "answer": "oy 2312/2019 DUE", "start": 38, "end": 40}, + ] + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual(nested_simplify(outputs, decimals=4), expected_output) + + outputs = dqa_pipeline({"image": image, "question": question}, top_k=2) + self.assertEqual(nested_simplify(outputs, decimals=4), expected_output) + + # This image does not detect ANY text in it, meaning layoutlmv2 should fail. + # Empty answer probably + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual(outputs, []) + + # We can optionnally pass directly the words and bounding boxes + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + words = [] + boxes = [] + outputs = dqa_pipeline(image=image, question=question, words=words, boxes=boxes, top_k=2) + self.assertEqual(outputs, []) + # TODO: Enable this once hf-internal-testing/tiny-random-donut is implemented # @require_torch # def test_small_model_pt_donut(self):