Fix object detection2 (#20798)

* Revert "Fixing object detection with `layoutlm` (#20776)"

This reverts commit fca66abe2a.

* Better fix for layoutlm object detection.

* Style.
This commit is contained in:
Nicolas Patry
2022-12-16 13:25:36 +01:00
committed by GitHub
parent 4341f4e224
commit 3ee958207a
2 changed files with 6 additions and 11 deletions

View File

@@ -378,12 +378,7 @@ NO_TOKENIZER_TASKS = set()
# any tokenizer/feature_extractor might be used for a given model so we cannot # any tokenizer/feature_extractor might be used for a given model so we cannot
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to # use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
# see if the model defines such objects or not. # see if the model defines such objects or not.
MULTI_MODEL_CONFIGS = { MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
"SpeechEncoderDecoderConfig",
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
"LayoutLMConfig",
}
for task, values in SUPPORTED_TASKS.items(): for task, values in SUPPORTED_TASKS.items():
if values["type"] == "text": if values["type"] == "text":
NO_FEATURE_EXTRACTOR_TASKS.add(task) NO_FEATURE_EXTRACTOR_TASKS.add(task)

View File

@@ -247,8 +247,8 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
@require_torch @require_torch
@slow @slow
def test_layoutlm(self): def test_layoutlm(self):
model_id = "philschmid/layoutlm-funsd" model_id = "Narsil/layoutlmv3-finetuned-funsd"
threshold = 0.998 threshold = 0.9993
object_detector = pipeline("object-detection", model=model_id, threshold=threshold) object_detector = pipeline("object-detection", model=model_id, threshold=threshold)
@@ -256,9 +256,9 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
"https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png" "https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png"
) )
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=3), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.998, "label": "B-QUESTION", "box": {"xmin": 462, "ymin": 234, "xmax": 508, "ymax": 249}}, {"score": 0.9993, "label": "I-ANSWER", "box": {"xmin": 294, "ymin": 254, "xmax": 343, "ymax": 264}},
{"score": 0.999, "label": "I-QUESTION", "box": {"xmin": 489, "ymin": 286, "xmax": 519, "ymax": 301}}, {"score": 0.9993, "label": "I-ANSWER", "box": {"xmin": 294, "ymin": 254, "xmax": 343, "ymax": 264}},
], ],
) )