Adding support for LayoutLMvX variants for object-detection. (#20143)
* Adding support for LayoutLMvX variants for `object-detection`. * Revert bogs `layoutlm` feature extractor which does not exist (it was a V2 model) . * Updated condition. * Handling the comments.
This commit is contained in:
@@ -345,7 +345,7 @@ SUPPORTED_TASKS = {
|
|||||||
"tf": (),
|
"tf": (),
|
||||||
"pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
|
"pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
|
||||||
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
|
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
|
||||||
"type": "image",
|
"type": "multimodal",
|
||||||
},
|
},
|
||||||
"zero-shot-object-detection": {
|
"zero-shot-object-detection": {
|
||||||
"impl": ZeroShotObjectDetectionPipeline,
|
"impl": ZeroShotObjectDetectionPipeline,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ if is_vision_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING
|
from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@@ -39,7 +39,9 @@ class ObjectDetectionPipeline(Pipeline):
|
|||||||
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
||||||
|
|
||||||
requires_backends(self, "vision")
|
requires_backends(self, "vision")
|
||||||
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
|
self.check_model_type(
|
||||||
|
dict(MODEL_FOR_OBJECT_DETECTION_MAPPING.items() + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items())
|
||||||
|
)
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(self, **kwargs):
|
||||||
postprocess_kwargs = {}
|
postprocess_kwargs = {}
|
||||||
@@ -82,6 +84,8 @@ class ObjectDetectionPipeline(Pipeline):
|
|||||||
image = load_image(image)
|
image = load_image(image)
|
||||||
target_size = torch.IntTensor([[image.height, image.width]])
|
target_size = torch.IntTensor([[image.height, image.width]])
|
||||||
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
||||||
|
if self.tokenizer is not None:
|
||||||
|
inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
|
||||||
inputs["target_size"] = target_size
|
inputs["target_size"] = target_size
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@@ -89,11 +93,39 @@ class ObjectDetectionPipeline(Pipeline):
|
|||||||
target_size = model_inputs.pop("target_size")
|
target_size = model_inputs.pop("target_size")
|
||||||
outputs = self.model(**model_inputs)
|
outputs = self.model(**model_inputs)
|
||||||
model_outputs = outputs.__class__({"target_size": target_size, **outputs})
|
model_outputs = outputs.__class__({"target_size": target_size, **outputs})
|
||||||
|
if self.tokenizer is not None:
|
||||||
|
model_outputs["bbox"] = model_inputs["bbox"]
|
||||||
return model_outputs
|
return model_outputs
|
||||||
|
|
||||||
def postprocess(self, model_outputs, threshold=0.9):
|
def postprocess(self, model_outputs, threshold=0.9):
|
||||||
target_size = model_outputs["target_size"]
|
target_size = model_outputs["target_size"]
|
||||||
raw_annotations = self.feature_extractor.post_process_object_detection(model_outputs, threshold, target_size)
|
if self.tokenizer is not None:
|
||||||
|
# This is a LayoutLMForTokenClassification variant.
|
||||||
|
# The OCR got the boxes and the model classified the words.
|
||||||
|
width, height = target_size[0].tolist()
|
||||||
|
|
||||||
|
def unnormalize(bbox):
|
||||||
|
return self._get_bounding_box(
|
||||||
|
torch.Tensor(
|
||||||
|
[
|
||||||
|
(width * bbox[0] / 1000),
|
||||||
|
(height * bbox[1] / 1000),
|
||||||
|
(width * bbox[2] / 1000),
|
||||||
|
(height * bbox[3] / 1000),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
scores, classes = model_outputs["logits"].squeeze(0).softmax(dim=-1).max(dim=-1)
|
||||||
|
labels = [self.model.config.id2label[prediction] for prediction in classes.tolist()]
|
||||||
|
boxes = [unnormalize(bbox) for bbox in model_outputs["bbox"].squeeze(0)]
|
||||||
|
keys = ["score", "label", "box"]
|
||||||
|
annotation = [dict(zip(keys, vals)) for vals in zip(scores.tolist(), labels, boxes) if vals[0] > threshold]
|
||||||
|
else:
|
||||||
|
# This is a regular ForObjectDetectionModel
|
||||||
|
raw_annotations = self.feature_extractor.post_process_object_detection(
|
||||||
|
model_outputs, threshold, target_size
|
||||||
|
)
|
||||||
raw_annotation = raw_annotations[0]
|
raw_annotation = raw_annotations[0]
|
||||||
scores = raw_annotation["scores"]
|
scores = raw_annotation["scores"]
|
||||||
labels = raw_annotation["labels"]
|
labels = raw_annotation["labels"]
|
||||||
|
|||||||
@@ -243,3 +243,30 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
|
|||||||
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
|
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_layoutlm(self):
|
||||||
|
model_id = "philschmid/layoutlm-funsd"
|
||||||
|
threshold = 0.998
|
||||||
|
|
||||||
|
object_detector = pipeline("object-detection", model=model_id, threshold=threshold)
|
||||||
|
|
||||||
|
outputs = object_detector(
|
||||||
|
"https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"score": 0.9982,
|
||||||
|
"label": "B-QUESTION",
|
||||||
|
"box": {"xmin": 654, "ymin": 165, "xmax": 719, "ymax": 719},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"score": 0.9982,
|
||||||
|
"label": "I-QUESTION",
|
||||||
|
"box": {"xmin": 691, "ymin": 202, "xmax": 735, "ymax": 735},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user