A few CI fixes for DocumentQuestionAnsweringPipeline (#19584)
* Fixes * update expected values * style * fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -235,7 +235,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
`word_boxes`).
|
`word_boxes`).
|
||||||
- **answer** (`str`) -- The answer to the question.
|
- **answer** (`str`) -- The answer to the question.
|
||||||
- **words** (`list[int]`) -- The index of each word/box pair that is in the answer
|
- **words** (`list[int]`) -- The index of each word/box pair that is in the answer
|
||||||
- **page** (`int`) -- The page of the answer
|
|
||||||
"""
|
"""
|
||||||
if isinstance(question, str):
|
if isinstance(question, str):
|
||||||
inputs = {"question": question, "image": image}
|
inputs = {"question": question, "image": image}
|
||||||
@@ -315,7 +314,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
"p_mask": None,
|
"p_mask": None,
|
||||||
"word_ids": None,
|
"word_ids": None,
|
||||||
"words": None,
|
"words": None,
|
||||||
"page": None,
|
|
||||||
"output_attentions": True,
|
"output_attentions": True,
|
||||||
"is_last": True,
|
"is_last": True,
|
||||||
}
|
}
|
||||||
@@ -339,6 +337,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
return_overflowing_tokens=True,
|
return_overflowing_tokens=True,
|
||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
)
|
)
|
||||||
|
encoding.pop("overflow_to_sample_mapping") # We do not use this
|
||||||
|
|
||||||
num_spans = len(encoding["input_ids"])
|
num_spans = len(encoding["input_ids"])
|
||||||
|
|
||||||
@@ -395,9 +394,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
words = model_inputs.pop("words", None)
|
words = model_inputs.pop("words", None)
|
||||||
is_last = model_inputs.pop("is_last", False)
|
is_last = model_inputs.pop("is_last", False)
|
||||||
|
|
||||||
if "overflow_to_sample_mapping" in model_inputs:
|
|
||||||
model_inputs.pop("overflow_to_sample_mapping")
|
|
||||||
|
|
||||||
if self.model_type == ModelType.VisionEncoderDecoder:
|
if self.model_type == ModelType.VisionEncoderDecoder:
|
||||||
model_outputs = self.model.generate(**model_inputs)
|
model_outputs = self.model.generate(**model_inputs)
|
||||||
else:
|
else:
|
||||||
@@ -421,7 +417,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
return answers
|
return answers
|
||||||
|
|
||||||
def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
|
def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
|
||||||
sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
|
sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0]
|
||||||
|
|
||||||
# TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
|
# TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
|
||||||
# (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
|
# (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
|
||||||
|
|||||||
@@ -209,8 +209,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
[
|
[
|
||||||
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
|
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
|
||||||
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,8 +218,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
[
|
[
|
||||||
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
|
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
|
||||||
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -230,8 +230,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
|
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
|
||||||
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
* 2,
|
* 2,
|
||||||
@@ -320,8 +320,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
[
|
[
|
||||||
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
|
||||||
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -332,8 +332,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
|
||||||
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
* 2,
|
* 2,
|
||||||
@@ -346,8 +346,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(outputs, decimals=4),
|
nested_simplify(outputs, decimals=4),
|
||||||
[
|
[
|
||||||
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
|
||||||
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
|
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user