From 8788fd0ceb2cea986cd4ca14b4eb4be554e1404c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 25 Jan 2023 15:46:10 +0100 Subject: [PATCH] Moving to cleaner tokenizer version for `oneformer`. (#21292) Moving to cleaner tokenizer version. --- .../oneformer/image_processing_oneformer.py | 17 ++- src/transformers/pipelines/__init__.py | 2 +- .../pipelines/image_segmentation.py | 21 +++- .../test_pipelines_image_segmentation.py | 102 ++++++++++++++++++ 4 files changed, 134 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index 6fd6a9ca48..2a679b95c8 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -518,8 +518,8 @@ class OneFormerImageProcessor(BaseImageProcessor): reduce_labels=reduce_labels, ) - def __call__(self, images, task_inputs, segmentation_maps=None, **kwargs) -> BatchFeature: - return self.preprocess(images, task_inputs, segmentation_maps=segmentation_maps, **kwargs) + def __call__(self, images, task_inputs=None, segmentation_maps=None, **kwargs) -> BatchFeature: + return self.preprocess(images, task_inputs=task_inputs, segmentation_maps=segmentation_maps, **kwargs) def _preprocess( self, @@ -604,7 +604,7 @@ class OneFormerImageProcessor(BaseImageProcessor): def preprocess( self, images: ImageInput, - task_inputs: List[str], + task_inputs: Optional[List[str]] = None, segmentation_maps: Optional[ImageInput] = None, instance_id_to_semantic_id: Optional[Dict[int, int]] = None, do_resize: Optional[bool] = None, @@ -639,6 +639,10 @@ class OneFormerImageProcessor(BaseImageProcessor): ) do_reduce_labels = kwargs.pop("reduce_labels") + if task_inputs is None: + # Default value + task_inputs = ["panoptic"] + do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size =
get_size_dict(size, default_to_square=False, max_size=self._max_size) @@ -973,8 +977,10 @@ class OneFormerImageProcessor(BaseImageProcessor): classes, masks, texts = self.get_semantic_annotations(label, num_class_obj) elif task == "instance": classes, masks, texts = self.get_instance_annotations(label, num_class_obj) - if task == "panoptic": + elif task == "panoptic": classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj) + else: + raise ValueError(f"{task} was not expected, expected `semantic`, `instance` or `panoptic`") # we cannot batch them since they don't share a common class size masks = [mask[None, ...] for mask in masks] @@ -990,6 +996,9 @@ class OneFormerImageProcessor(BaseImageProcessor): encoded_inputs["class_labels"] = class_labels encoded_inputs["text_inputs"] = text_inputs + # This needs to be tokenized before sending to the model. + encoded_inputs["task_inputs"] = [f"the task is {task_input}" for task_input in task_inputs] + return encoded_inputs # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 992f14f26d..e14d744579 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -331,7 +331,7 @@ SUPPORTED_TASKS = { "tf": (), "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (), "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}}, - "type": "image", + "type": "multimodal", }, "image-to-text": { "impl": ImageToTextPipeline, diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index 5be5b858dc..4e98fe8cbf 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -87,9 +87,11 @@ class ImageSegmentationPipeline(Pipeline): 
) def _sanitize_parameters(self, **kwargs): + preprocessor_kwargs = {} postprocess_kwargs = {} if "subtask" in kwargs: postprocess_kwargs["subtask"] = kwargs["subtask"] + preprocessor_kwargs["subtask"] = kwargs["subtask"] if "threshold" in kwargs: postprocess_kwargs["threshold"] = kwargs["threshold"] if "mask_threshold" in kwargs: @@ -97,7 +99,7 @@ class ImageSegmentationPipeline(Pipeline): if "overlap_mask_area_threshold" in kwargs: postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"] - return {}, {}, postprocess_kwargs + return preprocessor_kwargs, {}, postprocess_kwargs def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]: """ @@ -140,10 +142,23 @@ class ImageSegmentationPipeline(Pipeline): """ return super().__call__(images, **kwargs) - def preprocess(self, image): + def preprocess(self, image, subtask=None): image = load_image(image) target_size = [(image.height, image.width)] - inputs = self.image_processor(images=[image], return_tensors="pt") + if self.model.config.__class__.__name__ == "OneFormerConfig": + if subtask is None: + kwargs = {} + else: + kwargs = {"task_inputs": [subtask]} + inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs) + inputs["task_inputs"] = self.tokenizer( + inputs["task_inputs"], + padding="max_length", + max_length=self.model.config.task_seq_len, + return_tensors=self.framework, + )["input_ids"] + else: + inputs = self.image_processor(images=[image], return_tensors="pt") inputs["target_size"] = target_size return inputs diff --git a/tests/pipelines/test_pipelines_image_segmentation.py b/tests/pipelines/test_pipelines_image_segmentation.py index 8f022e68da..008c60a990 100644 --- a/tests/pipelines/test_pipelines_image_segmentation.py +++ b/tests/pipelines/test_pipelines_image_segmentation.py @@ -609,3 +609,105 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa }, ], ) + + @require_torch + @slow + def 
test_oneformer(self): + image_segmenter = pipeline(model="shi-labs/oneformer_ade20k_swin_tiny") + + image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + file = image[0]["file"] + outputs = image_segmenter(file, threshold=0.99) + # Shortening by hashing + for o in outputs: + o["mask"] = mask_to_test_readable(o["mask"]) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + { + "score": 0.9981, + "label": "grass", + "mask": {"hash": "3a92904d4c", "white_pixels": 118131, "shape": (512, 683)}, + }, + { + "score": 0.9992, + "label": "sky", + "mask": {"hash": "fa2300cc9a", "white_pixels": 231565, "shape": (512, 683)}, + }, + ], + ) + + # Different task + outputs = image_segmenter(file, threshold=0.99, subtask="instance") + # Shortening by hashing + for o in outputs: + o["mask"] = mask_to_test_readable(o["mask"]) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + { + "score": 0.9991, + "label": "sky", + "mask": {"hash": "8b1ffad016", "white_pixels": 230566, "shape": (512, 683)}, + }, + { + "score": 0.9981, + "label": "grass", + "mask": {"hash": "9bbdf83d3d", "white_pixels": 119130, "shape": (512, 683)}, + }, + ], + ) + + # Different task + outputs = image_segmenter(file, subtask="semantic") + # Shortening by hashing + for o in outputs: + o["mask"] = mask_to_test_readable(o["mask"]) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + { + "score": None, + "label": "wall", + "mask": {"hash": "897fb20b7f", "white_pixels": 14506, "shape": (512, 683)}, + }, + { + "score": None, + "label": "building", + "mask": {"hash": "f2a68c63e4", "white_pixels": 125019, "shape": (512, 683)}, + }, + { + "score": None, + "label": "sky", + "mask": {"hash": "e0ca3a548e", "white_pixels": 135330, "shape": (512, 683)}, + }, + { + "score": None, + "label": "tree", + "mask": {"hash": "7c9544bcac", "white_pixels": 16263, "shape": (512, 683)}, + }, + { + "score": None, + "label": "road, route", + "mask": {"hash": "2c7704e491", 
"white_pixels": 2143, "shape": (512, 683)}, + }, + { + "score": None, + "label": "grass", + "mask": {"hash": "bf6c2867e0", "white_pixels": 53040, "shape": (512, 683)}, + }, + { + "score": None, + "label": "plant", + "mask": {"hash": "93c4b7199e", "white_pixels": 3335, "shape": (512, 683)}, + }, + { + "score": None, + "label": "house", + "mask": {"hash": "93ec419ad5", "white_pixels": 60, "shape": (512, 683)}, + }, + ], + )