From d066c3731bed1755f93ea64f0f00981b805532de Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 10 Nov 2022 11:33:38 +0100 Subject: [PATCH] Adding support for LayoutLMvX variants for `object-detection`. (#20143) * Adding support for LayoutLMvX variants for `object-detection`. * Revert bogs `layoutlm` feature extractor which does not exist (it was a V2 model) . * Updated condition. * Handling the comments. --- src/transformers/pipelines/__init__.py | 2 +- .../pipelines/object_detection.py | 64 ++++++++++++++----- .../test_pipelines_object_detection.py | 27 ++++++++ 3 files changed, 76 insertions(+), 17 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index ee7b495627..1beda27ce8 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -345,7 +345,7 @@ SUPPORTED_TASKS = { "tf": (), "pt": (AutoModelForObjectDetection,) if is_torch_available() else (), "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}}, - "type": "image", + "type": "multimodal", }, "zero-shot-object-detection": { "impl": ZeroShotObjectDetectionPipeline, diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py index 7cf034b7c8..018ea9e63e 100644 --- a/src/transformers/pipelines/object_detection.py +++ b/src/transformers/pipelines/object_detection.py @@ -11,7 +11,7 @@ if is_vision_available(): if is_torch_available(): import torch - from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING + from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING logger = logging.get_logger(__name__) @@ -39,7 +39,9 @@ class ObjectDetectionPipeline(Pipeline): raise ValueError(f"The {self.__class__} is only available in PyTorch.") requires_backends(self, "vision") - self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING) + self.check_model_type( + dict(MODEL_FOR_OBJECT_DETECTION_MAPPING.items() + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items()) + ) def _sanitize_parameters(self, **kwargs): postprocess_kwargs = {} @@ -82,6 +84,8 @@ class ObjectDetectionPipeline(Pipeline): image = load_image(image) target_size = torch.IntTensor([[image.height, image.width]]) inputs = self.feature_extractor(images=[image], return_tensors="pt") + if self.tokenizer is not None: + inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt") inputs["target_size"] = target_size return inputs @@ -89,26 +93,54 @@ class ObjectDetectionPipeline(Pipeline): target_size = model_inputs.pop("target_size") outputs = self.model(**model_inputs) model_outputs = outputs.__class__({"target_size": target_size, **outputs}) + if self.tokenizer is not None: + model_outputs["bbox"] = model_inputs["bbox"] return model_outputs def postprocess(self, model_outputs, threshold=0.9): target_size = model_outputs["target_size"] - raw_annotations = self.feature_extractor.post_process_object_detection(model_outputs, threshold, target_size) - raw_annotation = raw_annotations[0] - scores = raw_annotation["scores"] - labels = raw_annotation["labels"] - boxes = raw_annotation["boxes"] + if self.tokenizer is not None: + # This is a LayoutLMForTokenClassification variant. + # The OCR got the boxes and the model classified the words. + width, height = target_size[0].tolist() - raw_annotation["scores"] = scores.tolist() - raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels] - raw_annotation["boxes"] = [self._get_bounding_box(box) for box in boxes] + def unnormalize(bbox): + return self._get_bounding_box( + torch.Tensor( + [ + (width * bbox[0] / 1000), + (height * bbox[1] / 1000), + (width * bbox[2] / 1000), + (height * bbox[3] / 1000), + ] + ) + ) - # {"scores": [...], ...} --> [{"score":x, ...}, ...] - keys = ["score", "label", "box"] - annotation = [ - dict(zip(keys, vals)) - for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["boxes"]) - ] + scores, classes = model_outputs["logits"].squeeze(0).softmax(dim=-1).max(dim=-1) + labels = [self.model.config.id2label[prediction] for prediction in classes.tolist()] + boxes = [unnormalize(bbox) for bbox in model_outputs["bbox"].squeeze(0)] + keys = ["score", "label", "box"] + annotation = [dict(zip(keys, vals)) for vals in zip(scores.tolist(), labels, boxes) if vals[0] > threshold] + else: + # This is a regular ForObjectDetectionModel + raw_annotations = self.feature_extractor.post_process_object_detection( + model_outputs, threshold, target_size + ) + raw_annotation = raw_annotations[0] + scores = raw_annotation["scores"] + labels = raw_annotation["labels"] + boxes = raw_annotation["boxes"] + + raw_annotation["scores"] = scores.tolist() + raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels] + raw_annotation["boxes"] = [self._get_bounding_box(box) for box in boxes] + + # {"scores": [...], ...} --> [{"score":x, ...}, ...] + keys = ["score", "label", "box"] + annotation = [ + dict(zip(keys, vals)) + for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["boxes"]) + ] return annotation diff --git a/tests/pipelines/test_pipelines_object_detection.py b/tests/pipelines/test_pipelines_object_detection.py index 196f4c82ac..23a6dab299 100644 --- a/tests/pipelines/test_pipelines_object_detection.py +++ b/tests/pipelines/test_pipelines_object_detection.py @@ -243,3 +243,30 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}}, ], ) + + @require_torch + @slow + def test_layoutlm(self): + model_id = "philschmid/layoutlm-funsd" + threshold = 0.998 + + object_detector = pipeline("object-detection", model=model_id, threshold=threshold) + + outputs = object_detector( + "https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png" + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + { + "score": 0.9982, + "label": "B-QUESTION", + "box": {"xmin": 654, "ymin": 165, "xmax": 719, "ymax": 719}, + }, + { + "score": 0.9982, + "label": "I-QUESTION", + "box": {"xmin": 691, "ymin": 202, "xmax": 735, "ymax": 735}, + }, + ], + )