Add DocumentQuestionAnswering pipeline (#18414)
* [WIP] Skeleton of VisualQuestionAnweringPipeline extended to support LayoutLM-like models * Fixup * Use the full encoding * Basic refactoring to DocumentQuestionAnsweringPipeline * Cleanup * Improve args, docs, and implement preprocessing * Integrate OCR * Refactor question_answering pipeline * Use refactored QA code in the document qa pipeline * Fix tests * Some small cleanups * Use a string type annotation for Image.Image * Update encoding with image features * Wire through the basic docs * Handle invalid response * Handle empty word_boxes properly * Docstring fix * Integrate Donut model * Fixup * Incorporate comments * Address comments * Initial incorporation of tests * Address Comments * Change assert to ValueError * Comments * Wrap `score` in float to make it JSON serializable * Incorporate AutoModeLForDocumentQuestionAnswering changes * Fixup * Rename postprocess function * Fix auto import * Applying comments * Improve docs * Remove extra assets and add copyright * Address comments Co-authored-by: Ankur Goyal <ankur@impira.com>
This commit is contained in:
@@ -89,6 +89,7 @@ if is_torch_available():
|
||||
MODEL_FOR_AUDIO_XVECTOR_MAPPING,
|
||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
@@ -172,7 +173,10 @@ class ModelTesterMixin:
|
||||
if return_labels:
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
elif model_class in [
|
||||
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
|
||||
*get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING),
|
||||
]:
|
||||
inputs_dict["start_positions"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
@@ -542,7 +546,10 @@ class ModelTesterMixin:
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
if model_class in [
|
||||
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING),
|
||||
*get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING),
|
||||
]:
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
||||
Reference in New Issue
Block a user