Implement multiple span support for DocumentQuestionAnswering (#19204)
* Implement multiple span support
* Address comments
* Add tests + fix bugs
This commit is contained in:
@@ -191,6 +191,52 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
* 2,
|
||||
)
|
||||
|
||||
@slow
@require_torch
@require_detectron2
@require_pytesseract
def test_large_model_pt_chunk(self):
    """Exercise the LayoutLMv2 DocVQA checkpoint with ``max_seq_len=50``.

    The invoice question is submitted three ways (keyword arguments, a single
    dict, and a list of two dicts); every form must return the same two top
    answers once scores are rounded to four decimals.

    NOTE(review): the short ``max_seq_len`` presumably forces the document to
    be split into several chunks, as the test name suggests -- confirm.
    """
    qa = pipeline(
        "document-question-answering",
        model="tiennvcs/layoutlmv2-base-uncased-finetuned-docvqa",
        revision="9977165",
        max_seq_len=50,
    )
    invoice = INVOICE_URL
    query = "What is the invoice number?"

    # The two best answers expected for a single (image, question) pair.
    top_two = [
        {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
        {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
    ]

    # 1) Inputs passed as keyword arguments.
    by_kwargs = qa(image=invoice, question=query, top_k=2)
    self.assertEqual(nested_simplify(by_kwargs, decimals=4), top_two)

    # 2) Inputs passed as one dict.
    by_dict = qa({"image": invoice, "question": query}, top_k=2)
    self.assertEqual(nested_simplify(by_dict, decimals=4), top_two)

    # 3) A batch of two identical requests yields one result list per request.
    batch = [{"image": invoice, "question": query} for _ in range(2)]
    by_batch = qa(batch, top_k=2)
    self.assertEqual(nested_simplify(by_batch, decimals=4), [top_two] * 2)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_pytesseract
|
||||
@@ -252,6 +298,59 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
],
|
||||
)
|
||||
|
||||
@slow
@require_torch
@require_pytesseract
@require_vision
def test_large_model_pt_layoutlm_chunk(self):
    """Exercise ``impira/layoutlm-document-qa`` with ``max_seq_len=50``.

    The invoice question is asked with keyword arguments, as a batch of two
    dicts, and finally with ``image=None`` plus word boxes produced by
    ``apply_tesseract``; each form must return the same two top answers once
    scores are rounded to four decimals.

    NOTE(review): the short ``max_seq_len`` presumably forces the document to
    be split into several chunks, as the test name suggests -- confirm.
    """
    checkpoint = "impira/layoutlm-document-qa"
    commit = "3dc6de3"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=commit, add_prefix_space=True)
    qa = pipeline(
        "document-question-answering",
        model=checkpoint,
        tokenizer=tokenizer,
        revision=commit,
        max_seq_len=50,
    )
    invoice = INVOICE_URL
    query = "What is the invoice number?"

    # Both top answers point at the same word; only their scores differ.
    top_two = [
        {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
        {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
    ]

    # 1) Inputs passed as keyword arguments.
    by_kwargs = qa(image=invoice, question=query, top_k=2)
    self.assertEqual(nested_simplify(by_kwargs, decimals=4), top_two)

    # 2) A batch of two identical requests yields one result list per request.
    batch = [{"image": invoice, "question": query} for _ in range(2)]
    by_batch = qa(batch, top_k=2)
    self.assertEqual(nested_simplify(by_batch, decimals=4), [top_two] * 2)

    # Run OCR ourselves so the pipeline can be fed word boxes directly.
    page = load_image(invoice)
    ocr_columns = apply_tesseract(page, None, "")
    word_boxes = list(zip(*ocr_columns))

    # This model should also work if `image` is set to None
    without_image = qa({"image": None, "word_boxes": word_boxes, "question": query}, top_k=2)
    self.assertEqual(nested_simplify(without_image, decimals=4), top_two)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt_donut(self):
|
||||
|
||||
Reference in New Issue
Block a user