fix document qa bf16 pipeline (#35456)
* fix document qa bf16 pipeline Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -485,6 +485,11 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
for output in model_outputs:
|
for output in model_outputs:
|
||||||
words = output["words"]
|
words = output["words"]
|
||||||
|
|
||||||
|
if self.framework == "pt" and output["start_logits"].dtype in (torch.bfloat16, torch.float16):
|
||||||
|
output["start_logits"] = output["start_logits"].float()
|
||||||
|
if self.framework == "pt" and output["end_logits"].dtype in (torch.bfloat16, torch.float16):
|
||||||
|
output["end_logits"] = output["end_logits"].float()
|
||||||
|
|
||||||
starts, ends, scores, min_null_score = select_starts_ends(
|
starts, ends, scores, min_null_score = select_starts_ends(
|
||||||
start=output["start_logits"],
|
start=output["start_logits"],
|
||||||
end=output["end_logits"],
|
end=output["end_logits"],
|
||||||
|
|||||||
@@ -14,7 +14,12 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, AutoTokenizer, is_vision_available
|
from transformers import (
|
||||||
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
|
||||||
|
AutoTokenizer,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
from transformers.pipelines import DocumentQuestionAnsweringPipeline, pipeline
|
from transformers.pipelines import DocumentQuestionAnsweringPipeline, pipeline
|
||||||
from transformers.pipelines.document_question_answering import apply_tesseract
|
from transformers.pipelines.document_question_answering import apply_tesseract
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -24,6 +29,7 @@ from transformers.testing_utils import (
|
|||||||
require_pytesseract,
|
require_pytesseract,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_bf16,
|
||||||
require_vision,
|
require_vision,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -31,6 +37,9 @@ from transformers.testing_utils import (
|
|||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -145,6 +154,42 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase):
|
|||||||
outputs = dqa_pipeline(image=image, question=question, words=words, boxes=boxes, top_k=2)
|
outputs = dqa_pipeline(image=image, question=question, words=words, boxes=boxes, top_k=2)
|
||||||
self.assertEqual(outputs, [])
|
self.assertEqual(outputs, [])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_torch_bf16
|
||||||
|
@require_detectron2
|
||||||
|
@require_pytesseract
|
||||||
|
def test_small_model_pt_bf16(self):
|
||||||
|
dqa_pipeline = pipeline(
|
||||||
|
"document-question-answering",
|
||||||
|
model="hf-internal-testing/tiny-random-layoutlmv2-for-dqa-test",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
image = INVOICE_URL
|
||||||
|
question = "How many cats are there?"
|
||||||
|
|
||||||
|
expected_output = [
|
||||||
|
{"score": 0.0001, "answer": "oy 2312/2019", "start": 38, "end": 39},
|
||||||
|
{"score": 0.0001, "answer": "oy 2312/2019 DUE", "start": 38, "end": 40},
|
||||||
|
]
|
||||||
|
outputs = dqa_pipeline(image=image, question=question, top_k=2)
|
||||||
|
self.assertEqual(nested_simplify(outputs, decimals=4), expected_output)
|
||||||
|
|
||||||
|
outputs = dqa_pipeline({"image": image, "question": question}, top_k=2)
|
||||||
|
self.assertEqual(nested_simplify(outputs, decimals=4), expected_output)
|
||||||
|
|
||||||
|
# This image does not detect ANY text in it, meaning layoutlmv2 should fail.
|
||||||
|
# Empty answer probably
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
outputs = dqa_pipeline(image=image, question=question, top_k=2)
|
||||||
|
self.assertEqual(outputs, [])
|
||||||
|
|
||||||
|
# We can optionnally pass directly the words and bounding boxes
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
words = []
|
||||||
|
boxes = []
|
||||||
|
outputs = dqa_pipeline(image=image, question=question, words=words, boxes=boxes, top_k=2)
|
||||||
|
self.assertEqual(outputs, [])
|
||||||
|
|
||||||
# TODO: Enable this once hf-internal-testing/tiny-random-donut is implemented
|
# TODO: Enable this once hf-internal-testing/tiny-random-donut is implemented
|
||||||
# @require_torch
|
# @require_torch
|
||||||
# def test_small_model_pt_donut(self):
|
# def test_small_model_pt_donut(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user