diff --git a/docs/source/en/model_doc/mask2former.md b/docs/source/en/model_doc/mask2former.md index f27fd5948f..04968e27e3 100644 --- a/docs/source/en/model_doc/mask2former.md +++ b/docs/source/en/model_doc/mask2former.md @@ -77,4 +77,12 @@ The resource should ideally demonstrate something new instead of duplicating an - encode_inputs - post_process_semantic_segmentation - post_process_instance_segmentation + - post_process_panoptic_segmentation + +## Mask2FormerImageProcessorFast + +[[autodoc]] Mask2FormerImageProcessorFast + - preprocess + - post_process_semantic_segmentation + - post_process_instance_segmentation - post_process_panoptic_segmentation \ No newline at end of file diff --git a/docs/source/en/model_doc/maskformer.md b/docs/source/en/model_doc/maskformer.md index fcfe11ec55..cd84cd9ffd 100644 --- a/docs/source/en/model_doc/maskformer.md +++ b/docs/source/en/model_doc/maskformer.md @@ -76,6 +76,14 @@ This model was contributed by [francesco](https://huggingface.co/francesco). 
The - post_process_instance_segmentation - post_process_panoptic_segmentation +## MaskFormerImageProcessorFast + +[[autodoc]] MaskFormerImageProcessorFast + - preprocess + - post_process_semantic_segmentation + - post_process_instance_segmentation + - post_process_panoptic_segmentation + ## MaskFormerFeatureExtractor [[autodoc]] MaskFormerFeatureExtractor diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b3f4658fdd..f169cf413a 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -118,8 +118,8 @@ else: ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), ("llava_next_video", ("LlavaNextVideoImageProcessor",)), ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")), - ("mask2former", ("Mask2FormerImageProcessor",)), - ("maskformer", ("MaskFormerImageProcessor",)), + ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")), + ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")), ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")), ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")), ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")), diff --git a/src/transformers/models/mask2former/__init__.py b/src/transformers/models/mask2former/__init__.py index 752787e79a..a4281e77df 100644 --- a/src/transformers/models/mask2former/__init__.py +++ b/src/transformers/models/mask2former/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_mask2former import * from .image_processing_mask2former import * + from .image_processing_mask2former_fast import * from .modeling_mask2former import * else: import sys diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py 
b/src/transformers/models/mask2former/image_processing_mask2former.py index ef4f4701e7..20610e7c77 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former.py +++ b/src/transformers/models/mask2former/image_processing_mask2former.py @@ -35,8 +35,8 @@ from ...image_utils import ( PILImageResampling, get_image_size, infer_channel_dimension_format, - is_batched, is_scaled_image, + make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, @@ -61,6 +61,46 @@ if is_torch_available(): from torch import nn +# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + # Copied from transformers.models.detr.image_processing_detr.max_across_indices def max_across_indices(values: Iterable[Any]) -> list[Any]: """ @@ -394,6 +434,10 @@ class Mask2FormerImageProcessor(BaseImageProcessor): 
The background label will be replaced by `ignore_index`. num_labels (`int`, *optional*): The number of labels in the segmentation map. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. """ model_input_names = ["pixel_values", "pixel_mask"] @@ -416,6 +460,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): ignore_index: Optional[int] = None, do_reduce_labels: bool = False, num_labels: Optional[int] = None, + pad_size: Optional[dict[str, int]] = None, **kwargs, ): super().__init__(**kwargs) @@ -439,6 +484,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): self.ignore_index = ignore_index self.do_reduce_labels = do_reduce_labels self.num_labels = num_labels + self.pad_size = pad_size @classmethod def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): @@ -697,6 +743,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, ) -> BatchFeature: do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size @@ -710,6 +757,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std ignore_index = ignore_index if ignore_index is not None else self.ignore_index do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + pad_size = self.pad_size if pad_size is None else pad_size if not valid_images(images): raise ValueError( @@ -734,9 +782,9 @@ class Mask2FormerImageProcessor(BaseImageProcessor): "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if not is_batched(images): - images = [images] - segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) if segmentation_maps is not None and len(images) != len(segmentation_maps): raise ValueError("Images and segmentation maps must have the same length.") @@ -774,6 +822,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): do_reduce_labels, return_tensors, input_data_format=data_format, + pad_size=pad_size, ) return encoded_inputs @@ -805,7 +854,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): ) return padded_image - # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad + # Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.pad def pad( self, images: list[np.ndarray], @@ -814,6 +863,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, ) -> BatchFeature: """ Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width @@ -837,13 +887,21 @@ class Mask2FormerImageProcessor(BaseImageProcessor): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. 
""" - pad_size = get_max_height_width(images, input_data_format=input_data_format) + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) padded_images = [ self._pad_image( image, - pad_size, + padded_size, constant_values=constant_values, data_format=data_format, input_data_format=input_data_format, @@ -854,7 +912,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): if return_pixel_mask: masks = [ - make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format) + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) for image in images ] data["pixel_mask"] = masks @@ -870,6 +928,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): do_reduce_labels: bool = False, return_tensors: Optional[Union[str, TensorType]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, ): """ Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. @@ -909,6 +968,11 @@ class Mask2FormerImageProcessor(BaseImageProcessor): input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. 
+ Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -930,7 +994,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor): input_data_format = infer_channel_dimension_format(pixel_values_list[0]) encoded_inputs = self.pad( - pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format + pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format, pad_size=pad_size ) if segmentation_maps is not None: diff --git a/src/transformers/models/mask2former/image_processing_mask2former_fast.py b/src/transformers/models/mask2former/image_processing_mask2former_fast.py new file mode 100644 index 0000000000..1f61e9b0cd --- /dev/null +++ b/src/transformers/models/mask2former/image_processing_mask2former_fast.py @@ -0,0 +1,737 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mask2former/modular_mask2former.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mask2former.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from typing import Any, Optional, Union + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, + get_image_size_for_max_height_width, + get_max_height_width, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from .image_processing_mask2former import ( + compute_segments, + convert_segmentation_to_rle, + get_size_with_aspect_ratio, + remove_low_and_no_objects, +) + + +if is_torch_available(): + import torch + from torch import nn + + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +logger = logging.get_logger(__name__) + + +class Mask2FormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + size_divisor (`int`, *optional*, defaults to 32): + Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in + Swin Transformer. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). + The background label will be replaced by `ignore_index`. 
+ num_labels (`int`, *optional*): + The number of labels in the segmentation map. + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + + size_divisor: Optional[int] + ignore_index: Optional[int] + do_reduce_labels: Optional[bool] + num_labels: Optional[int] + do_pad: Optional[bool] + pad_size: Optional[dict[str, int]] + + +def convert_segmentation_map_to_binary_masks_fast( + segmentation_map: "torch.Tensor", + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + ignore_index: Optional[int] = None, + do_reduce_labels: bool = False, +): + if do_reduce_labels and ignore_index is None: + raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.") + + if do_reduce_labels: + segmentation_map = torch.where(segmentation_map == 0, ignore_index, segmentation_map - 1) + + all_labels = torch.unique(segmentation_map) + + if ignore_index is not None: + all_labels = all_labels[all_labels != ignore_index] # drop background label if applicable + + binary_masks = [(segmentation_map == i) for i in all_labels] + if binary_masks: + binary_masks = torch.stack(binary_masks, dim=0) + else: + binary_masks = torch.zeros((0, *segmentation_map.shape), device=segmentation_map.device) + + # Convert instance ids to class ids + if instance_id_to_semantic_id is not None: + labels = torch.zeros(all_labels.shape[0], 
device=segmentation_map.device) + + for i, label in enumerate(all_labels): + class_id = instance_id_to_semantic_id[(label.item() + 1 if do_reduce_labels else label.item())] + labels[i] = class_id - 1 if do_reduce_labels else class_id + else: + labels = all_labels + return binary_masks.float(), labels.long() + + +@auto_docstring +class Mask2FormerImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": 800, "longest_edge": 1333} + default_to_square = False + do_resize = True + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = True + do_pad = True + model_input_names = ["pixel_values", "pixel_mask"] + size_divisor = 32 + do_reduce_labels = False + valid_kwargs = Mask2FormerFastImageProcessorKwargs + + @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0") + @deprecate_kwarg("size_divisibility", new_name="size_divisor", version="4.41.0") + @deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True) + def __init__(self, **kwargs: Unpack[Mask2FormerFastImageProcessorKwargs]) -> None: + if "pad_and_return_pixel_mask" in kwargs: + kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask") + + size = kwargs.pop("size", None) + max_size = kwargs.pop("max_size", None) + + if size is None and max_size is not None: + size = self.size + size["longest_edge"] = max_size + elif size is None: + size = self.size + + self.size = get_size_dict(size, max_size=max_size, default_to_square=False) + + super().__init__(**kwargs) + + @classmethod + def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. 
`Mask2FormerImageProcessor.from_pretrained(checkpoint, max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "size_divisibility" in kwargs: + image_processor_dict["size_divisor"] = kwargs.pop("size_divisibility") + if "reduce_labels" in image_processor_dict: + image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. This method calls the superclass method and then removes the + `_max_size` attribute from the dictionary. + """ + image_processor_dict = super().to_dict() + image_processor_dict.pop("_max_size", None) + return image_processor_dict + + def reduce_label(self, labels: list["torch.Tensor"]): + for idx in range(len(labels)): + label = labels[idx] + label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label) + label = label - 1 + label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label) + labels[idx] = label + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + size_divisor: int = 0, + interpolation: "F.InterpolationMode" = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. 
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + size_divisor (`int`, *optional*, defaults to 0): + If `size_divisor` is given, the output image size will be divisible by the number. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use if resizing the image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if size.shortest_edge and size.longest_edge: + # Resize the image so that the shortest edge or the longest edge is of the given size + # while maintaining the aspect ratio of the original image. + new_size = get_size_with_aspect_ratio( + image.size()[-2:], + size["shortest_edge"], + size["longest_edge"], + ) + elif size.max_height and size.max_width: + new_size = get_image_size_for_max_height_width(image.size()[-2:], size["max_height"], size["max_width"]) + elif size.height and size.width: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." 
+ ) + if size_divisor > 0: + height, width = new_size + height = int(math.ceil(height / size_divisor) * size_divisor) + width = int(math.ceil(width / size_divisor) * size_divisor) + new_size = (height, width) + + image = F.resize( + image, + size=new_size, + interpolation=interpolation, + **kwargs, + ) + return image + + def pad( + self, + images: torch.Tensor, + padded_size: tuple[int, int], + segmentation_maps: Optional[torch.Tensor] = None, + fill: int = 0, + ignore_index: int = 255, + ) -> BatchFeature: + original_size = images.size()[-2:] + padding_bottom = padded_size[0] - original_size[0] + padding_right = padded_size[1] - original_size[1] + if padding_bottom < 0 or padding_right < 0: + raise ValueError( + f"Padding dimensions are negative. Please make sure that the padded size is larger than the " + f"original size. Got padded size: {padded_size}, original size: {original_size}." + ) + if original_size != padded_size: + padding = [0, 0, padding_right, padding_bottom] + images = F.pad(images, padding, fill=fill) + if segmentation_maps is not None: + segmentation_maps = F.pad(segmentation_maps, padding, fill=ignore_index) + + # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + pixel_mask = torch.zeros((images.shape[0], *padded_size), dtype=torch.int64, device=images.device) + pixel_mask[:, : original_size[0], : original_size[1]] = 1 + + return images, pixel_mask, segmentation_maps + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None, + **kwargs: Unpack[Mask2FormerFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps. + instance_id_to_semantic_id (`Union[list[dict[int, int]], dict[int, int]]`, *optional*): + A mapping from instance IDs to semantic IDs. 
+ """ + return super().preprocess( + images, + segmentation_maps, + instance_id_to_semantic_id, + **kwargs, + ) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: ImageInput, + instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[Mask2FormerFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + To be overriden by subclasses when image-like inputs other than images should be processed. + It can be used for segmentation maps, depth maps, etc. + """ + # Prepare input images + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + if segmentation_maps is not None: + segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + return self._preprocess(images, segmentation_maps, instance_id_to_semantic_id, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + segmentation_maps: Optional["torch.Tensor"], + instance_id_to_semantic_id: Optional[dict[int, int]], + do_resize: Optional[bool], + size: Optional[dict[str, int]], + pad_size: Optional[dict[str, int]], + size_divisor: Optional[int], + interpolation: Optional[Union["PILImageResampling", "F.InterpolationMode"]], + do_rescale: Optional[bool], + rescale_factor: Optional[float], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + ignore_index: Optional[int], + do_reduce_labels: Optional[bool], + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + if segmentation_maps is not None and len(images) != 
len(segmentation_maps): + raise ValueError("Images and segmentation maps must have the same length.") + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + if segmentation_maps is not None: + grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape( + segmentation_maps, disable_grouping=disable_grouping + ) + resized_segmentation_maps_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation + ) + if segmentation_maps is not None: + stacked_segmentation_maps = self.resize( + image=grouped_segmentation_maps[shape], + size=size, + size_divisor=size_divisor, + interpolation=F.InterpolationMode.NEAREST_EXACT, + ) + resized_images_grouped[shape] = stacked_images + if segmentation_maps is not None: + resized_segmentation_maps_grouped[shape] = stacked_segmentation_maps + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + if segmentation_maps is not None: + resized_segmentation_maps = reorder_images( + resized_segmentation_maps_grouped, grouped_segmentation_maps_index + ) + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(resized_images) + + if segmentation_maps is not None: + mask_labels = [] + class_labels = [] + # Convert to list of binary masks and labels + for idx, segmentation_map in enumerate(resized_segmentation_maps): + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = convert_segmentation_map_to_binary_masks_fast( + segmentation_map.squeeze(0), + instance_id, + ignore_index=ignore_index, + 
do_reduce_labels=do_reduce_labels, + ) + mask_labels.append(masks) + class_labels.append(classes) + + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_pixel_masks_grouped = {} + if segmentation_maps is not None: + grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape( + mask_labels, disable_grouping=disable_grouping + ) + processed_segmentation_maps_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + padded_images, pixel_masks, padded_segmentation_maps = self.pad( + images=stacked_images, + segmentation_maps=grouped_segmentation_maps[shape] if segmentation_maps is not None else None, + padded_size=padded_size, + ignore_index=ignore_index, + ) + processed_images_grouped[shape] = padded_images + processed_pixel_masks_grouped[shape] = pixel_masks + if segmentation_maps is not None: + processed_segmentation_maps_grouped[shape] = padded_segmentation_maps.squeeze(1) + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index) + encoded_inputs = BatchFeature( + data={ + "pixel_values": torch.stack(processed_images, dim=0) if return_tensors else processed_images, + "pixel_mask": torch.stack(processed_pixel_masks, dim=0) if return_tensors else processed_pixel_masks, + }, + tensor_type=return_tensors, + ) + if segmentation_maps is not None: + mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_segmentation_maps_index) + # we cannot batch them since they don't share a common class size + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels + + return encoded_inputs + + def 
post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None + ) -> "torch.Tensor": + """ + Converts the output of [`Mask2FormerForUniversalSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`Mask2FormerForUniversalSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + Returns: + `List[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. + """ + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in 
range(batch_size): + resized_logits = torch.nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[list[tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + return_binary_maps: Optional[bool] = False, + ) -> list[dict]: + """ + Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into instance segmentation predictions. + Only supports PyTorch. If instances could overlap, set either return_coco_annotation or return_binary_maps + to `True` to get the correct segmentation result. + + Args: + outputs ([`Mask2FormerForUniversalSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. 
+ return_coco_annotation (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format. + return_binary_maps (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps + (one per detected instance). + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`, or + `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`, or a tensor of shape `(num_instances, height, width)` if return_binary_maps is set to `True`. + Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. 
+ """ + if return_coco_annotation and return_binary_maps: + raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.") + + # [batch_size, num_queries, num_classes+1] + class_queries_logits = outputs.class_queries_logits + # [batch_size, num_queries, height, width] + masks_queries_logits = outputs.masks_queries_logits + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + device = masks_queries_logits.device + num_classes = class_queries_logits.shape[-1] - 1 + num_queries = class_queries_logits.shape[-2] + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(class_queries_logits.shape[0]): + mask_pred = masks_queries_logits[i] + mask_cls = class_queries_logits[i] + + scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1] + labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + + scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False) + labels_per_image = labels[topk_indices] + + topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor") + mask_pred = mask_pred[topk_indices] + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores_per_image * mask_scores_per_image + pred_classes = labels_per_image + + segmentation = torch.zeros((384, 384)) - 1 + if target_sizes is not None: + segmentation = torch.zeros(target_sizes[i]) - 1 + pred_masks = torch.nn.functional.interpolate( + pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest" + )[0] + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = 
pred_scores[j].item() + + if not torch.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "was_fused": False, + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + # Return a concatenated tensor of binary instance maps + if return_binary_maps and len(instance_maps) != 0: + segmentation = torch.stack(instance_maps, dim=0) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_sizes: Optional[list[tuple[int, int]]] = None, + ) -> list[dict]: + """ + Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into image panoptic segmentation + predictions. Only supports PyTorch. + + Args: + outputs ([`Mask2FormerForUniversalSegmentationOutput`]): + The outputs from [`Mask2FormerForUniversalSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. 
For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If left to None, predictions will not be + resized. + + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set + to `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized + to the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning("`label_ids_to_fuse` unset. 
No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["Mask2FormerImageProcessorFast"] diff --git 
a/src/transformers/models/mask2former/modular_mask2former.py b/src/transformers/models/mask2former/modular_mask2former.py new file mode 100644 index 0000000000..eed1e7adc9 --- /dev/null +++ b/src/transformers/models/mask2former/modular_mask2former.py @@ -0,0 +1,315 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from transformers.models.maskformer.image_processing_maskformer_fast import MaskFormerImageProcessorFast + +from ...utils import ( + TensorType, + is_torch_available, + logging, +) +from .image_processing_mask2former import ( + compute_segments, + convert_segmentation_to_rle, + remove_low_and_no_objects, +) + + +if is_torch_available(): + import torch + from torch import nn + + +logger = logging.get_logger(__name__) + + +class Mask2FormerImageProcessorFast(MaskFormerImageProcessorFast): + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None + ) -> "torch.Tensor": + """ + Converts the output of [`Mask2FormerForUniversalSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`Mask2FormerForUniversalSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. 
If left to None, predictions will not be resized. + Returns: + `List[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. + """ + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = torch.nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + 
mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[list[tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + return_binary_maps: Optional[bool] = False, + ) -> list[dict]: + """ + Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into instance segmentation predictions. + Only supports PyTorch. If instances could overlap, set either return_coco_annotation or return_binary_maps + to `True` to get the correct segmentation result. + + Args: + outputs ([`Mask2FormerForUniversalSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + return_coco_annotation (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format. + return_binary_maps (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps + (one per detected instance). 
+ Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id`, or + `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`, or a tensor of shape `(num_instances, height, width)` if return_binary_maps is set to `True`. + Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + if return_coco_annotation and return_binary_maps: + raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.") + + # [batch_size, num_queries, num_classes+1] + class_queries_logits = outputs.class_queries_logits + # [batch_size, num_queries, height, width] + masks_queries_logits = outputs.masks_queries_logits + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + device = masks_queries_logits.device + num_classes = class_queries_logits.shape[-1] - 1 + num_queries = class_queries_logits.shape[-2] + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(class_queries_logits.shape[0]): + mask_pred = masks_queries_logits[i] + mask_cls = class_queries_logits[i] + + scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1] + labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + + scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False) + labels_per_image = labels[topk_indices] 
+ + topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor") + mask_pred = mask_pred[topk_indices] + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores_per_image * mask_scores_per_image + pred_classes = labels_per_image + + segmentation = torch.zeros((384, 384)) - 1 + if target_sizes is not None: + segmentation = torch.zeros(target_sizes[i]) - 1 + pred_masks = torch.nn.functional.interpolate( + pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest" + )[0] + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not torch.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "was_fused": False, + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + # Return a concatenated tensor of binary instance maps + if return_binary_maps and len(instance_maps) != 0: + segmentation = torch.stack(instance_maps, dim=0) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[set[int]] = None, + target_sizes: Optional[list[tuple[int, int]]] = None, + ) -> list[dict]: + """ + Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into image panoptic segmentation + predictions. Only supports PyTorch. 
+ + Args: + outputs ([`Mask2FormerForUniversalSegmentationOutput`]): + The outputs from [`Mask2FormerForUniversalSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If left to None, predictions will not be + resized. + + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set + to `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized + to the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. 
+ """ + + if label_ids_to_fuse is None: + logger.warning("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + def 
post_process_segmentation(): + raise NotImplementedError("Segmentation post-processing is not implemented for Mask2Former yet.") + + +__all__ = ["Mask2FormerImageProcessorFast"] diff --git a/src/transformers/models/maskformer/__init__.py b/src/transformers/models/maskformer/__init__.py index 144ae7d0a7..3a91c136c2 100644 --- a/src/transformers/models/maskformer/__init__.py +++ b/src/transformers/models/maskformer/__init__.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from .configuration_maskformer_swin import * from .feature_extraction_maskformer import * from .image_processing_maskformer import * + from .image_processing_maskformer_fast import * from .modeling_maskformer import * from .modeling_maskformer_swin import * else: diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 91bdf020f2..52dc99b4c3 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -67,6 +67,46 @@ if is_torch_available(): from torch import nn +# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. 
+ """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + # Copied from transformers.models.detr.image_processing_detr.max_across_indices def max_across_indices(values: Iterable[Any]) -> list[Any]: """ @@ -399,6 +439,10 @@ class MaskFormerImageProcessor(BaseImageProcessor): The background label will be replaced by `ignore_index`. num_labels (`int`, *optional*): The number of labels in the segmentation map. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. 
""" @@ -422,6 +466,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): ignore_index: Optional[int] = None, do_reduce_labels: bool = False, num_labels: Optional[int] = None, + pad_size: Optional[dict[str, int]] = None, **kwargs, ): super().__init__(**kwargs) @@ -445,6 +490,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): self.ignore_index = ignore_index self.do_reduce_labels = do_reduce_labels self.num_labels = num_labels + self.pad_size = pad_size @classmethod def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): @@ -700,6 +746,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, ) -> BatchFeature: do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size @@ -713,6 +760,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): image_std = image_std if image_std is not None else self.image_std ignore_index = ignore_index if ignore_index is not None else self.ignore_index do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + pad_size = self.pad_size if pad_size is None else pad_size if not valid_images(images): raise ValueError( @@ -777,6 +825,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): do_reduce_labels, return_tensors, input_data_format=data_format, + pad_size=pad_size, ) return encoded_inputs @@ -808,7 +857,6 @@ class MaskFormerImageProcessor(BaseImageProcessor): ) return padded_image - # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.pad def pad( self, images: list[np.ndarray], @@ -817,6 +865,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = None, 
input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, ) -> BatchFeature: """ Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width @@ -840,13 +889,21 @@ class MaskFormerImageProcessor(BaseImageProcessor): The channel dimension format of the image. If not provided, it will be the same as the input image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format of the input image. If not provided, it will be inferred. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. """ - pad_size = get_max_height_width(images, input_data_format=input_data_format) + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) padded_images = [ self._pad_image( image, - pad_size, + padded_size, constant_values=constant_values, data_format=data_format, input_data_format=input_data_format, @@ -857,7 +914,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): if return_pixel_mask: masks = [ - make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format) + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) for image in images ] data["pixel_mask"] = masks @@ -873,6 +930,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): do_reduce_labels: bool = False, return_tensors: Optional[Union[str, TensorType]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, ): """ Pad images up to the largest image in a batch and create a 
corresponding `pixel_mask`. @@ -909,6 +967,11 @@ class MaskFormerImageProcessor(BaseImageProcessor): If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` objects. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -930,7 +993,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): input_data_format = infer_channel_dimension_format(pixel_values_list[0]) encoded_inputs = self.pad( - pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format + pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format, pad_size=pad_size ) if segmentation_maps is not None: diff --git a/src/transformers/models/maskformer/image_processing_maskformer_fast.py b/src/transformers/models/maskformer/image_processing_maskformer_fast.py new file mode 100644 index 0000000000..bdd13afbe6 --- /dev/null +++ b/src/transformers/models/maskformer/image_processing_maskformer_fast.py @@ -0,0 +1,775 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def convert_segmentation_map_to_binary_masks_fast(
    segmentation_map: "torch.Tensor",
    instance_id_to_semantic_id: Optional[dict[int, int]] = None,
    ignore_index: Optional[int] = None,
    do_reduce_labels: bool = False,
) -> tuple["torch.Tensor", "torch.Tensor"]:
    """
    Split a label map into one binary mask per distinct label and return the matching class ids.

    Args:
        segmentation_map (`torch.Tensor`):
            Label map. Every distinct value (other than `ignore_index`) produces one binary mask.
        instance_id_to_semantic_id (`dict[int, int]`, *optional*):
            Mapping from the ids stored in `segmentation_map` (as passed in, i.e. before label reduction) to
            class ids. When omitted, the ids found in the map are returned as the class ids.
        ignore_index (`int`, *optional*):
            Value that never produces a mask. Required when `do_reduce_labels` is `True`.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            If `True`, pixels equal to 0 are replaced by `ignore_index` and every other value is decremented
            by 1. Class ids looked up through `instance_id_to_semantic_id` are decremented by 1 as well.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: float32 masks of shape `(num_labels, *segmentation_map.shape)` and
        int64 class ids of shape `(num_labels,)`. `num_labels` is 0 when only `ignore_index` is present.

    Raises:
        ValueError: If `do_reduce_labels` is `True` and `ignore_index` is `None`.
        KeyError: If an id found in the map is missing from `instance_id_to_semantic_id`.
    """
    if do_reduce_labels and ignore_index is None:
        raise ValueError("If `do_reduce_labels` is True, `ignore_index` must be provided.")

    if do_reduce_labels:
        segmentation_map = torch.where(segmentation_map == 0, ignore_index, segmentation_map - 1)

    all_labels = torch.unique(segmentation_map)

    if ignore_index is not None:
        all_labels = all_labels[all_labels != ignore_index]  # drop background label if applicable

    # One broadcast comparison replaces the per-label Python loop. With zero labels it naturally yields an
    # empty `(0, *segmentation_map.shape)` tensor, so no special case is needed.
    labels_column = all_labels.reshape(-1, *([1] * segmentation_map.dim()))
    binary_masks = segmentation_map.unsqueeze(0) == labels_column

    # Convert instance ids to class ids
    if instance_id_to_semantic_id is not None:
        # Reduced labels are shifted by one relative to the mapping keys; undo the shift for the lookup and
        # re-apply it to the class id that comes back.
        shift = 1 if do_reduce_labels else 0
        class_ids = [instance_id_to_semantic_id[label + shift] - shift for label in all_labels.tolist()]
        # Build the result as int64 directly instead of round-tripping through a float32 buffer.
        labels = torch.tensor(class_ids, dtype=torch.long, device=segmentation_map.device)
    else:
        labels = all_labels
    return binary_masks.float(), labels.long()
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.44.0")
@deprecate_kwarg("size_divisibility", new_name="size_divisor", version="4.41.0")
@deprecate_kwarg("max_size", version="4.27.0", warn_if_greater_or_equal_version=True)
def __init__(self, **kwargs: Unpack[MaskFormerFastImageProcessorKwargs]) -> None:
    """
    Build the processor, resolving the legacy `pad_and_return_pixel_mask` and `max_size` keyword arguments.

    `size` and `max_size` are consumed here and stored on `self.size`; every remaining keyword argument is
    forwarded to the base class constructor.
    """
    # Legacy alias: `pad_and_return_pixel_mask` is treated as `do_pad`.
    if "pad_and_return_pixel_mask" in kwargs:
        kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")

    size = kwargs.pop("size", None)
    max_size = kwargs.pop("max_size", None)

    if size is None:
        # `self.size` still resolves to the class-level default dict at this point. Copy it before applying
        # `max_size`, otherwise the override would be written into the shared default and leak into every
        # instance created afterwards.
        size = dict(self.size)
        if max_size is not None:
            size["longest_edge"] = max_size

    self.size = get_size_dict(size, max_size=max_size, default_to_square=False)

    super().__init__(**kwargs)
def resize(
    self,
    image: "torch.Tensor",
    size: SizeDict,
    size_divisor: int = 0,
    interpolation: "F.InterpolationMode" = None,
    **kwargs,
) -> "torch.Tensor":
    """
    Resize `image` to the target described by `size`, optionally rounding both sides up to a multiple of
    `size_divisor`.

    Args:
        image (`torch.Tensor`):
            Image to resize. Its last two dimensions are read as `(height, width)`.
        size (`SizeDict`):
            Target size. The first matching option is used:
                - `shortest_edge` and `longest_edge`: output size computed by `get_size_with_aspect_ratio`.
                - `max_height` and `max_width`: output size computed by `get_image_size_for_max_height_width`.
                - `height` and `width`: resized to exactly `(height, width)`.
        size_divisor (`int`, *optional*, defaults to 0):
            If greater than 0, the computed height and width are each rounded up to the next multiple of this
            value before resizing. `None` and 0 disable the rounding.
        interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
            Resampling filter to use when resizing the image.
        **kwargs:
            Forwarded unchanged to `F.resize`.

    Returns:
        `torch.Tensor`: The resized image.

    Raises:
        ValueError: If `size` matches none of the supported options.
    """
    interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
    input_size = image.size()[-2:]
    if size.shortest_edge and size.longest_edge:
        # Resize the image so that the shortest edge or the longest edge is of the given size
        # while maintaining the aspect ratio of the original image.
        new_size = get_size_with_aspect_ratio(input_size, size.shortest_edge, size.longest_edge)
    elif size.max_height and size.max_width:
        new_size = get_image_size_for_max_height_width(input_size, size.max_height, size.max_width)
    elif size.height and size.width:
        new_size = (size.height, size.width)
    else:
        # NOTE(review): `size` is read by attribute everywhere above, so it is not assumed to expose a
        # dict-style `.keys()`; format the object itself so this path reliably raises `ValueError`.
        raise ValueError(
            "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
            f" {size}."
        )
    # `size_divisor` is optional in the processor kwargs, so tolerate `None` instead of failing on `None > 0`.
    if size_divisor is not None and size_divisor > 0:
        height, width = new_size
        height = int(math.ceil(height / size_divisor) * size_divisor)
        width = int(math.ceil(width / size_divisor) * size_divisor)
        new_size = (height, width)

    image = F.resize(
        image,
        size=new_size,
        interpolation=interpolation,
        **kwargs,
    )
    return image
def _preprocess(
    self,
    images: list["torch.Tensor"],
    segmentation_maps: Optional["torch.Tensor"],
    instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]],
    do_resize: Optional[bool],
    size: Optional[dict[str, int]],
    pad_size: Optional[dict[str, int]],
    size_divisor: Optional[int],
    interpolation: Optional[Union["PILImageResampling", "F.InterpolationMode"]],
    do_rescale: Optional[bool],
    rescale_factor: Optional[float],
    do_normalize: Optional[bool],
    image_mean: Optional[Union[float, list[float]]],
    image_std: Optional[Union[float, list[float]]],
    ignore_index: Optional[int],
    do_reduce_labels: Optional[bool],
    disable_grouping: Optional[bool],
    return_tensors: Optional[Union[str, TensorType]],
    **kwargs,
) -> BatchFeature:
    """
    Resize, rescale/normalize and pad `images` (and `segmentation_maps`, when given) into a `BatchFeature`.

    The steps are: (1) optionally resize images, and segmentation maps with nearest-exact interpolation;
    (2) convert every segmentation map into binary masks plus class labels; (3) rescale/normalize the images
    and pad images and masks to a common size, which is `pad_size` when given and otherwise the largest
    height and width among the resized images.

    Returns:
        `BatchFeature` with `pixel_values` and `pixel_mask`, plus `mask_labels` and `class_labels` (one entry
        per image) when `segmentation_maps` is provided.

    Raises:
        ValueError: If `images` and `segmentation_maps` have different lengths.
    """
    if segmentation_maps is not None and len(images) != len(segmentation_maps):
        raise ValueError("Images and segmentation maps must have the same length.")

    # Group images by size for batched resizing
    grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
    resized_images_grouped = {}
    if segmentation_maps is not None:
        grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape(
            segmentation_maps, disable_grouping=disable_grouping
        )
        resized_segmentation_maps_grouped = {}
    for shape, stacked_images in grouped_images.items():
        if segmentation_maps is not None:
            # Always start from this group's own maps so that `do_resize=False` passes them through
            # untouched instead of reading an unbound (or previous-iteration) variable.
            stacked_segmentation_maps = grouped_segmentation_maps[shape]
        if do_resize:
            stacked_images = self.resize(
                image=stacked_images, size=size, size_divisor=size_divisor, interpolation=interpolation
            )
            if segmentation_maps is not None:
                # Nearest-exact keeps label values intact (no blending between neighbouring ids).
                stacked_segmentation_maps = self.resize(
                    image=stacked_segmentation_maps,
                    size=size,
                    size_divisor=size_divisor,
                    interpolation=F.InterpolationMode.NEAREST_EXACT,
                )
        resized_images_grouped[shape] = stacked_images
        if segmentation_maps is not None:
            resized_segmentation_maps_grouped[shape] = stacked_segmentation_maps
    resized_images = reorder_images(resized_images_grouped, grouped_images_index)
    if segmentation_maps is not None:
        resized_segmentation_maps = reorder_images(
            resized_segmentation_maps_grouped, grouped_segmentation_maps_index
        )
    if pad_size is not None:
        padded_size = (pad_size["height"], pad_size["width"])
    else:
        padded_size = get_max_height_width(resized_images)

    if segmentation_maps is not None:
        mask_labels = []
        class_labels = []
        # Convert to list of binary masks and labels
        for idx, segmentation_map in enumerate(resized_segmentation_maps):
            # Use instance2class_id mapping per image
            if isinstance(instance_id_to_semantic_id, list):
                instance_id = instance_id_to_semantic_id[idx]
            else:
                instance_id = instance_id_to_semantic_id
            masks, classes = convert_segmentation_map_to_binary_masks_fast(
                segmentation_map.squeeze(0),
                instance_id,
                ignore_index=ignore_index,
                do_reduce_labels=do_reduce_labels,
            )
            mask_labels.append(masks)
            class_labels.append(classes)

    # Regroup: resizing may have changed which images share a shape.
    grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
    processed_images_grouped = {}
    processed_pixel_masks_grouped = {}
    if segmentation_maps is not None:
        grouped_segmentation_maps, grouped_segmentation_maps_index = group_images_by_shape(
            mask_labels, disable_grouping=disable_grouping
        )
        processed_segmentation_maps_grouped = {}
    for shape, stacked_images in grouped_images.items():
        # Fused rescale and normalize
        stacked_images = self.rescale_and_normalize(
            stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )
        padded_images, pixel_masks, padded_segmentation_maps = self.pad(
            images=stacked_images,
            segmentation_maps=grouped_segmentation_maps[shape] if segmentation_maps is not None else None,
            padded_size=padded_size,
            ignore_index=ignore_index,
        )
        processed_images_grouped[shape] = padded_images
        processed_pixel_masks_grouped[shape] = pixel_masks
        if segmentation_maps is not None:
            processed_segmentation_maps_grouped[shape] = padded_segmentation_maps.squeeze(1)
    processed_images = reorder_images(processed_images_grouped, grouped_images_index)
    processed_pixel_masks = reorder_images(processed_pixel_masks_grouped, grouped_images_index)
    encoded_inputs = BatchFeature(
        data={
            "pixel_values": torch.stack(processed_images, dim=0) if return_tensors else processed_images,
            "pixel_mask": torch.stack(processed_pixel_masks, dim=0) if return_tensors else processed_pixel_masks,
        },
        tensor_type=return_tensors,
    )
    if segmentation_maps is not None:
        mask_labels = reorder_images(processed_segmentation_maps_grouped, grouped_segmentation_maps_index)
        # we cannot batch them since they don't share a common class size
        encoded_inputs["mask_labels"] = mask_labels
        encoded_inputs["class_labels"] = class_labels

    return encoded_inputs
def post_process_segmentation(
    self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Optional[tuple[int, int]] = None
) -> "torch.Tensor":
    """
    Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image segmentation predictions. Only
    supports PyTorch.

    Deprecated: emits a `FutureWarning` on every call; use `post_process_instance_segmentation` instead.

    Args:
        outputs ([`MaskFormerForInstanceSegmentationOutput`]):
            The outputs from [`MaskFormerForInstanceSegmentation`]. Only `class_queries_logits` and
            `masks_queries_logits` are read.

        target_size (`tuple[int, int]`, *optional*):
            If set, the `masks_queries_logits` will be bilinearly resized to `target_size` first.

    Returns:
        `torch.Tensor`:
            A tensor of shape (`batch_size, num_class_labels, height, width`).
    """
    # `stacklevel=2` attributes the warning to the code calling this method rather than to this file, so
    # users can locate the deprecated call site.
    warnings.warn(
        "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
        " `post_process_instance_segmentation`",
        FutureWarning,
        stacklevel=2,
    )

    # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
    class_queries_logits = outputs.class_queries_logits
    # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
    masks_queries_logits = outputs.masks_queries_logits
    if target_size is not None:
        masks_queries_logits = torch.nn.functional.interpolate(
            masks_queries_logits,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
    # remove the null class `[..., :-1]`
    masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
    # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
    masks_probs = masks_queries_logits.sigmoid()
    # now we want to sum over the queries,
    # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
    # where $ softmax(p) \in R^{q, c} $ is the mask classes
    # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
    # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
    segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)

    return segmentation
transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[list[tuple[int, int]]] = None + ) -> "torch.Tensor": + """ + Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentation`]): + Raw outputs of the model. + target_sizes (`list[tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + Returns: + `list[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. 
def post_process_instance_segmentation(
    self,
    outputs,
    threshold: float = 0.5,
    mask_threshold: float = 0.5,
    overlap_mask_area_threshold: float = 0.8,
    target_sizes: Optional[list[tuple[int, int]]] = None,
    return_coco_annotation: Optional[bool] = False,
    return_binary_maps: Optional[bool] = False,
) -> list[dict]:
    """
    Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only
    supports PyTorch. If instances could overlap, set either return_coco_annotation or return_binary_maps
    to `True` to get the correct segmentation result.

    Args:
        outputs ([`MaskFormerForInstanceSegmentation`]):
            Raw outputs of the model. Only `class_queries_logits` and `masks_queries_logits` are read.
        threshold (`float`, *optional*, defaults to 0.5):
            The probability score threshold to keep predicted instance masks.
        mask_threshold (`float`, *optional*, defaults to 0.5):
            Accepted for interface compatibility; not read by this implementation, which binarizes mask
            logits at 0.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            Accepted for interface compatibility; not read by this implementation.
        target_sizes (`list[Tuple]`, *optional*):
            List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested
            final size (height, width) of each prediction. If left to None, predictions will not be resized.
        return_coco_annotation (`bool`, *optional*, defaults to `False`):
            If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format.
        return_binary_maps (`bool`, *optional*, defaults to `False`):
            If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps
            (one per detected instance).

    Returns:
        `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
        - **segmentation** -- A tensor of shape `(height, width)` where each pixel holds a `segment_id` (pixels
          not covered by any kept instance hold -1), or the run-length encoding (RLE) of that map if
          return_coco_annotation is set to `True`, or a tensor of shape `(num_instances, height, width)` if
          return_binary_maps is set to `True` and at least one instance was kept.
        - **segments_info** -- A list with one dictionary per kept segment:
            - **id** -- An integer representing the `segment_id`.
            - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
            - **was_fused** -- Always `False` here.
            - **score** -- Prediction score of segment with `segment_id`, rounded to 6 decimals.

    Raises:
        ValueError: If both `return_coco_annotation` and `return_binary_maps` are `True`.
    """
    # NOTE: this body intentionally differs from the slow processor's copy (device placement of
    # `segmentation`, hoisted label grid), so it must not be re-synced verbatim.
    if return_coco_annotation and return_binary_maps:
        raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.")

    # [batch_size, num_queries, num_classes+1]
    class_queries_logits = outputs.class_queries_logits
    # [batch_size, num_queries, height, width]
    masks_queries_logits = outputs.masks_queries_logits

    device = masks_queries_logits.device
    num_classes = class_queries_logits.shape[-1] - 1
    num_queries = class_queries_logits.shape[-2]

    # Class id of every flattened (query, class) pair. It does not depend on the batch item, so build it once.
    labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)

    # Loop over items in batch size
    results: list[dict] = []

    for i in range(class_queries_logits.shape[0]):
        mask_pred = masks_queries_logits[i]
        mask_cls = class_queries_logits[i]

        # Drop the trailing null class before ranking (query, class) pairs.
        scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1]

        scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
        labels_per_image = labels[topk_indices]

        # Map flattened (query, class) indices back to their query index.
        topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
        mask_pred = mask_pred[topk_indices]
        pred_masks = (mask_pred > 0).float()

        # Calculate average mask prob
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / (
            pred_masks.flatten(1).sum(1) + 1e-6
        )
        pred_scores = scores_per_image * mask_scores_per_image
        pred_classes = labels_per_image

        # Allocate the id map on the same device as the masks: it is indexed with boolean masks living on
        # `device`, which fails when the map itself is left on the CPU.
        if target_sizes is not None:
            segmentation = torch.zeros(target_sizes[i], device=device) - 1
            pred_masks = torch.nn.functional.interpolate(
                pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest"
            )[0]
        else:
            segmentation = torch.zeros(masks_queries_logits.shape[2:], device=device) - 1

        instance_maps, segments = [], []
        current_segment_id = 0
        for j in range(num_queries):
            score = pred_scores[j].item()

            if not torch.all(pred_masks[j] == 0) and score >= threshold:
                segmentation[pred_masks[j] == 1] = current_segment_id
                segments.append(
                    {
                        "id": current_segment_id,
                        "label_id": pred_classes[j].item(),
                        "was_fused": False,
                        "score": round(score, 6),
                    }
                )
                current_segment_id += 1
                instance_maps.append(pred_masks[j])

        # Return segmentation map in run-length encoding (RLE) format
        if return_coco_annotation:
            segmentation = convert_segmentation_to_rle(segmentation)

        # Return a concatenated tensor of binary instance maps
        if return_binary_maps and len(instance_maps) != 0:
            segmentation = torch.stack(instance_maps, dim=0)

        results.append({"segmentation": segmentation, "segments_info": segments})
    return results
For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`list[Tuple]`, *optional*): + List of length (batch_size), where each list item (`tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If left to None, predictions will not be + resized. + + Returns: + `list[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set + to `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized + to the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning("`label_ids_to_fuse` unset. 
No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: list[dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["MaskFormerImageProcessorFast"] diff --git a/tests/models/mask2former/test_image_processing_mask2former.py b/tests/models/mask2former/test_image_processing_mask2former.py index 457f5ad0af..c2685df19f 100644 --- a/tests/models/mask2former/test_image_processing_mask2former.py +++ 
b/tests/models/mask2former/test_image_processing_mask2former.py @@ -21,7 +21,7 @@ from huggingface_hub import hf_hub_download from transformers.image_utils import ChannelDimension from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -34,6 +34,9 @@ if is_torch_available(): from transformers.models.mask2former.image_processing_mask2former import binary_mask_to_rle from transformers.models.mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentationOutput + if is_torchvision_available(): + from transformers import Mask2FormerImageProcessorFast + if is_vision_available(): from PIL import Image @@ -54,6 +57,7 @@ class Mask2FormerImageProcessingTester: num_labels=10, do_reduce_labels=True, ignore_index=255, + pad_size=None, ): self.parent = parent self.batch_size = batch_size @@ -66,6 +70,7 @@ class Mask2FormerImageProcessingTester: self.image_mean = image_mean self.image_std = image_std self.size_divisor = 0 + self.pad_size = pad_size # for the post_process_functions self.batch_size = 2 self.num_queries = 3 @@ -87,6 +92,7 @@ class Mask2FormerImageProcessingTester: "num_labels": self.num_labels, "do_reduce_labels": self.do_reduce_labels, "ignore_index": self.ignore_index, + "pad_size": self.pad_size, } def get_expected_values(self, image_inputs, batched=False): @@ -145,10 +151,26 @@ class Mask2FormerImageProcessingTester: ) +# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_single_inputs +def prepare_semantic_single_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + example = ds[0] + return example["image"], example["map"] + + +# Copied from 
transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_batch_inputs +def prepare_semantic_batch_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + return list(ds["image"][:2]), list(ds["map"][:2]) + + @require_torch @require_vision class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = Mask2FormerImageProcessor if (is_vision_available() and is_torch_available()) else None + fast_image_processing_class = ( + Mask2FormerImageProcessorFast if (is_vision_available() and is_torchvision_available()) else None + ) def setUp(self): super().setUp() @@ -159,25 +181,27 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "ignore_index")) - self.assertTrue(hasattr(image_processing, "num_labels")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "ignore_index")) + self.assertTrue(hasattr(image_processing, "num_labels")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = 
self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 32, "longest_edge": 1333}) - self.assertEqual(image_processor.size_divisor, 0) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 32, "longest_edge": 1333}) + self.assertEqual(image_processor.size_divisor, 0) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, size=42, max_size=84, size_divisibility=8 - ) - self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84}) - self.assertEqual(image_processor.size_divisor, 8) + image_processor = image_processing_class.from_dict( + self.image_processor_dict, size=42, max_size=84, size_divisibility=8 + ) + self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(image_processor.size_divisor, 8) def comm_get_image_processing_inputs( self, @@ -225,15 +249,16 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase def test_with_size_divisor(self): size_divisors = [8, 16, 32] weird_input_sizes = [(407, 802), (582, 1094)] - for size_divisor in size_divisors: - image_processor_dict = {**self.image_processor_dict, **{"size_divisor": size_divisor}} - image_processing = self.image_processing_class(**image_processor_dict) - for weird_input_size in weird_input_sizes: - inputs = image_processing([np.ones((3, *weird_input_size))], return_tensors="pt") - pixel_values = inputs["pixel_values"] - # check if divisible - self.assertTrue((pixel_values.shape[-1] % size_divisor) == 0) - self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0) + for image_processing_class in self.image_processor_list: + for size_divisor in size_divisors: + image_processor_dict = {**self.image_processor_dict, **{"size_divisor": size_divisor}} + image_processing = 
image_processing_class(**image_processor_dict) + for weird_input_size in weird_input_sizes: + inputs = image_processing([np.ones((3, *weird_input_size))], return_tensors="pt") + pixel_values = inputs["pixel_values"] + # check if divisible + self.assertTrue((pixel_values.shape[-1] % size_divisor) == 0) + self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0) def test_call_with_segmentation_maps(self): def common( @@ -463,81 +488,85 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase self.assertEqual(rle[1], 45) def test_post_process_semantic_segmentation(self): - fature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_mask2former_outputs() + for image_processing_class in self.image_processor_list: + feature_extractor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_mask2former_outputs() - segmentation = fature_extractor.post_process_semantic_segmentation(outputs) + segmentation = feature_extractor.post_process_semantic_segmentation(outputs) - self.assertEqual(len(segmentation), self.image_processor_tester.batch_size) - self.assertEqual(segmentation[0].shape, (384, 384)) + self.assertEqual(len(segmentation), self.image_processor_tester.batch_size) + self.assertEqual(segmentation[0].shape, (384, 384)) - target_sizes = [(1, 4) for i in range(self.image_processor_tester.batch_size)] - segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) + target_sizes = [(1, 4) for i in range(self.image_processor_tester.batch_size)] + segmentation = feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) - self.assertEqual(segmentation[0].shape, target_sizes[0]) + self.assertEqual(segmentation[0].shape, target_sizes[0]) def test_post_process_instance_segmentation(self): - image_processor =
self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_mask2former_outputs() - segmentation = image_processor.post_process_instance_segmentation(outputs, threshold=0) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_mask2former_outputs() + segmentation = image_processor.post_process_instance_segmentation(outputs, threshold=0) - self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) - for el in segmentation: - self.assertTrue("segmentation" in el) - self.assertTrue("segments_info" in el) - self.assertEqual(type(el["segments_info"]), list) - self.assertEqual(el["segmentation"].shape, (384, 384)) + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(el["segmentation"].shape, (384, 384)) - segmentation = image_processor.post_process_instance_segmentation( - outputs, threshold=0, return_binary_maps=True - ) + segmentation = image_processor.post_process_instance_segmentation( + outputs, threshold=0, return_binary_maps=True + ) - self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) - for el in segmentation: - self.assertTrue("segmentation" in el) - self.assertTrue("segments_info" in el) - self.assertEqual(type(el["segments_info"]), list) - self.assertEqual(len(el["segmentation"].shape), 3) - self.assertEqual(el["segmentation"].shape[1:], (384, 384)) + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + 
self.assertEqual(len(el["segmentation"].shape), 3) + self.assertEqual(el["segmentation"].shape[1:], (384, 384)) def test_post_process_panoptic_segmentation(self): - image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_mask2former_outputs() - segmentation = image_processing.post_process_panoptic_segmentation(outputs, threshold=0) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_mask2former_outputs() + segmentation = image_processing.post_process_panoptic_segmentation(outputs, threshold=0) - self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) - for el in segmentation: - self.assertTrue("segmentation" in el) - self.assertTrue("segments_info" in el) - self.assertEqual(type(el["segments_info"]), list) - self.assertEqual(el["segmentation"].shape, (384, 384)) + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(el["segmentation"].shape, (384, 384)) def test_post_process_label_fusing(self): - image_processor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_mask2former_outputs() + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_mask2former_outputs() - segmentation = image_processor.post_process_panoptic_segmentation( - outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0 - ) - unfused_segments = [el["segments_info"] for el in segmentation] + segmentation = 
image_processor.post_process_panoptic_segmentation( + outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0 + ) + unfused_segments = [el["segments_info"] for el in segmentation] - fused_segmentation = image_processor.post_process_panoptic_segmentation( - outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1} - ) - fused_segments = [el["segments_info"] for el in fused_segmentation] + fused_segmentation = image_processor.post_process_panoptic_segmentation( + outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1} + ) + fused_segments = [el["segments_info"] for el in fused_segmentation] - for el_unfused, el_fused in zip(unfused_segments, fused_segments): - if len(el_unfused) == 0: - self.assertEqual(len(el_unfused), len(el_fused)) - continue + for el_unfused, el_fused in zip(unfused_segments, fused_segments): + if len(el_unfused) == 0: + self.assertEqual(len(el_unfused), len(el_fused)) + continue - # Get number of segments to be fused - fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}] - num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1 - # Expected number of segments after fusing - expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse - num_segments_fused = max([el["id"] for el in el_fused]) - self.assertEqual(num_segments_fused, expected_num_segments) + # Get number of segments to be fused + fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}] + num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1 + # Expected number of segments after fusing + expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse + num_segments_fused = max([el["id"] for el in el_fused]) + self.assertEqual(num_segments_fused, expected_num_segments) def test_removed_deprecated_kwargs(self): image_processor_dict = dict(self.image_processor_dict) @@ -545,9 +574,58 @@ class 
Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase image_processor_dict["reduce_labels"] = True # test we are able to create the image processor with the deprecated kwargs - image_processor = self.image_processing_class(**image_processor_dict) - self.assertEqual(image_processor.do_reduce_labels, True) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**image_processor_dict) + self.assertEqual(image_processor.do_reduce_labels, True) - # test we still support reduce_labels with config - image_processor = self.image_processing_class.from_dict(image_processor_dict) - self.assertEqual(image_processor.do_reduce_labels, True) + # test we still support reduce_labels with config + image_processor = image_processing_class.from_dict(image_processor_dict) + self.assertEqual(image_processor.do_reduce_labels, True) + + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image, dummy_map = prepare_semantic_single_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values) + for mask_label_slow, mask_label_fast in zip(image_encoding_slow.mask_labels, image_encoding_fast.mask_labels): + 
self._assert_slow_fast_tensors_equivalence(mask_label_slow, mask_label_fast) + for class_label_slow, class_label_fast in zip( + image_encoding_slow.class_labels, image_encoding_fast.class_labels + ): + self._assert_slow_fast_tensors_equivalence(class_label_slow.float(), class_label_fast.float()) + + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images, dummy_maps = prepare_semantic_batch_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + for mask_label_slow, mask_label_fast in zip(encoding_slow.mask_labels, encoding_fast.mask_labels): + self._assert_slow_fast_tensors_equivalence(mask_label_slow, mask_label_fast) + for class_label_slow, class_label_fast in zip(encoding_slow.class_labels, encoding_fast.class_labels): + self._assert_slow_fast_tensors_equivalence(class_label_slow.float(), class_label_fast.float()) diff --git a/tests/models/maskformer/test_image_processing_maskformer.py b/tests/models/maskformer/test_image_processing_maskformer.py 
index be31e4f692..01a9f0d086 100644 --- a/tests/models/maskformer/test_image_processing_maskformer.py +++ b/tests/models/maskformer/test_image_processing_maskformer.py @@ -20,7 +20,7 @@ from datasets import load_dataset from huggingface_hub import hf_hub_download from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -33,6 +33,9 @@ if is_torch_available(): from transformers.models.maskformer.image_processing_maskformer import binary_mask_to_rle from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput + if is_torchvision_available(): + from transformers import MaskFormerImageProcessorFast + if is_vision_available(): from PIL import Image @@ -144,10 +147,26 @@ class MaskFormerImageProcessingTester: ) +# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_single_inputs +def prepare_semantic_single_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + example = ds[0] + return example["image"], example["map"] + + +# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_batch_inputs +def prepare_semantic_batch_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + return list(ds["image"][:2]), list(ds["map"][:2]) + + @require_torch @require_vision class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = MaskFormerImageProcessor if (is_vision_available() and is_torch_available()) else None + fast_image_processing_class = ( + MaskFormerImageProcessorFast if (is_vision_available() and is_torchvision_available()) else None + ) def setUp(self): super().setUp() @@ -158,25 +177,27 @@ 
class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "ignore_index")) - self.assertTrue(hasattr(image_processing, "num_labels")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "ignore_index")) + self.assertTrue(hasattr(image_processing, "num_labels")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 32, "longest_edge": 1333}) - self.assertEqual(image_processor.size_divisor, 0) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 32, "longest_edge": 1333}) + self.assertEqual(image_processor.size_divisor, 0) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, size=42, max_size=84, size_divisibility=8 - ) - self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84}) - 
self.assertEqual(image_processor.size_divisor, 8) + image_processor = image_processing_class.from_dict( + self.image_processor_dict, size=42, max_size=84, size_divisibility=8 + ) + self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(image_processor.size_divisor, 8) def comm_get_image_processing_inputs( self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np" @@ -211,15 +232,16 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) def test_with_size_divisor(self): size_divisors = [8, 16, 32] weird_input_sizes = [(407, 802), (582, 1094)] - for size_divisor in size_divisors: - image_processor_dict = {**self.image_processor_dict, **{"size_divisor": size_divisor}} - image_processing = self.image_processing_class(**image_processor_dict) - for weird_input_size in weird_input_sizes: - inputs = image_processing([np.ones((3, *weird_input_size))], return_tensors="pt") - pixel_values = inputs["pixel_values"] - # check if divisible - self.assertTrue((pixel_values.shape[-1] % size_divisor) == 0) - self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0) + for image_processing_class in self.image_processor_list: + for size_divisor in size_divisors: + image_processor_dict = {**self.image_processor_dict, **{"size_divisor": size_divisor}} + image_processing = image_processing_class(**image_processor_dict) + for weird_input_size in weird_input_sizes: + inputs = image_processing([np.ones((3, *weird_input_size))], return_tensors="pt") + pixel_values = inputs["pixel_values"] + # check if divisible + self.assertTrue((pixel_values.shape[-1] % size_divisor) == 0) + self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0) def test_call_with_segmentation_maps(self): def common(is_instance_map=False, segmentation_type=None): @@ -417,116 +439,122 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) self.assertEqual(rle[1], 45) def 
test_post_process_segmentation(self): - fature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_maskformer_outputs() - segmentation = fature_extractor.post_process_segmentation(outputs) + for image_processing_class in self.image_processor_list: + feature_extractor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_maskformer_outputs() + segmentation = feature_extractor.post_process_segmentation(outputs) - self.assertEqual( - segmentation.shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_classes, - self.image_processor_tester.height, - self.image_processor_tester.width, - ), - ) + self.assertEqual( + segmentation.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_classes, + self.image_processor_tester.height, + self.image_processor_tester.width, + ), + ) - target_size = (1, 4) - segmentation = fature_extractor.post_process_segmentation(outputs, target_size=target_size) + target_size = (1, 4) + segmentation = feature_extractor.post_process_segmentation(outputs, target_size=target_size) - self.assertEqual( - segmentation.shape, - (self.image_processor_tester.batch_size, self.image_processor_tester.num_classes, *target_size), - ) + self.assertEqual( + segmentation.shape, + (self.image_processor_tester.batch_size, self.image_processor_tester.num_classes, *target_size), + ) def test_post_process_semantic_segmentation(self): - fature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_maskformer_outputs() + for image_processing_class in self.image_processor_list: + feature_extractor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_maskformer_outputs() - segmentation = 
fature_extractor.post_process_semantic_segmentation(outputs) + segmentation = feature_extractor.post_process_semantic_segmentation(outputs) - self.assertEqual(len(segmentation), self.image_processor_tester.batch_size) - self.assertEqual( - segmentation[0].shape, - ( - self.image_processor_tester.height, - self.image_processor_tester.width, - ), - ) + self.assertEqual(len(segmentation), self.image_processor_tester.batch_size) + self.assertEqual( + segmentation[0].shape, + ( + self.image_processor_tester.height, + self.image_processor_tester.width, + ), + ) - target_sizes = [(1, 4) for i in range(self.image_processor_tester.batch_size)] - segmentation = fature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) + target_sizes = [(1, 4) for i in range(self.image_processor_tester.batch_size)] + segmentation = feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) - self.assertEqual(segmentation[0].shape, target_sizes[0]) + self.assertEqual(segmentation[0].shape, target_sizes[0]) def test_post_process_instance_segmentation(self): - image_processor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_maskformer_outputs() - segmentation = image_processor.post_process_instance_segmentation(outputs, threshold=0) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_maskformer_outputs() + segmentation = image_processor.post_process_instance_segmentation(outputs, threshold=0) - self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) - for el in segmentation: - self.assertTrue("segmentation" in el) - self.assertTrue("segments_info" in el) - self.assertEqual(type(el["segments_info"]), list) - self.assertEqual( - el["segmentation"].shape, 
(self.image_processor_tester.height, self.image_processor_tester.width) + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual( + el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width) + ) + + segmentation = image_processor.post_process_instance_segmentation( + outputs, threshold=0, return_binary_maps=True ) - segmentation = image_processor.post_process_instance_segmentation( - outputs, threshold=0, return_binary_maps=True - ) - - self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) - for el in segmentation: - self.assertTrue("segmentation" in el) - self.assertTrue("segments_info" in el) - self.assertEqual(type(el["segments_info"]), list) - self.assertEqual(len(el["segmentation"].shape), 3) - self.assertEqual( - el["segmentation"].shape[1:], (self.image_processor_tester.height, self.image_processor_tester.width) - ) + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(len(el["segmentation"].shape), 3) + self.assertEqual( + el["segmentation"].shape[1:], + (self.image_processor_tester.height, self.image_processor_tester.width), + ) def test_post_process_panoptic_segmentation(self): - image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_maskformer_outputs() - segmentation = image_processing.post_process_panoptic_segmentation(outputs, threshold=0) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = 
self.image_processor_tester.get_fake_maskformer_outputs() + segmentation = image_processing.post_process_panoptic_segmentation(outputs, threshold=0) - self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) - for el in segmentation: - self.assertTrue("segmentation" in el) - self.assertTrue("segments_info" in el) - self.assertEqual(type(el["segments_info"]), list) - self.assertEqual( - el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width) - ) + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual( + el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width) + ) def test_post_process_label_fusing(self): - image_processor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) - outputs = self.image_processor_tester.get_fake_maskformer_outputs() + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_maskformer_outputs() - segmentation = image_processor.post_process_panoptic_segmentation( - outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0 - ) - unfused_segments = [el["segments_info"] for el in segmentation] + segmentation = image_processor.post_process_panoptic_segmentation( + outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0 + ) + unfused_segments = [el["segments_info"] for el in segmentation] - fused_segmentation = image_processor.post_process_panoptic_segmentation( - outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1} - ) - fused_segments = [el["segments_info"] for el in fused_segmentation] +
fused_segmentation = image_processor.post_process_panoptic_segmentation( + outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1} + ) + fused_segments = [el["segments_info"] for el in fused_segmentation] - for el_unfused, el_fused in zip(unfused_segments, fused_segments): - if len(el_unfused) == 0: - self.assertEqual(len(el_unfused), len(el_fused)) - continue + for el_unfused, el_fused in zip(unfused_segments, fused_segments): + if len(el_unfused) == 0: + self.assertEqual(len(el_unfused), len(el_fused)) + continue - # Get number of segments to be fused - fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}] - num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1 - # Expected number of segments after fusing - expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse - num_segments_fused = max([el["id"] for el in el_fused]) - self.assertEqual(num_segments_fused, expected_num_segments) + # Get number of segments to be fused + fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}] + num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1 + # Expected number of segments after fusing + expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse + num_segments_fused = max([el["id"] for el in el_fused]) + self.assertEqual(num_segments_fused, expected_num_segments) def test_removed_deprecated_kwargs(self): image_processor_dict = dict(self.image_processor_dict) @@ -540,3 +568,50 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) # test we still support reduce_labels with config image_processor = self.image_processing_class.from_dict(image_processor_dict) self.assertEqual(image_processor.do_reduce_labels, True) + + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class 
is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image, dummy_map = prepare_semantic_single_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values) + for mask_label_slow, mask_label_fast in zip(image_encoding_slow.mask_labels, image_encoding_fast.mask_labels): + self._assert_slow_fast_tensors_equivalence(mask_label_slow, mask_label_fast) + for class_label_slow, class_label_fast in zip( + image_encoding_slow.class_labels, image_encoding_fast.class_labels + ): + self._assert_slow_fast_tensors_equivalence(class_label_slow.float(), class_label_fast.float()) + + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images, dummy_maps = prepare_semantic_batch_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = 
self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + for mask_label_slow, mask_label_fast in zip(encoding_slow.mask_labels, encoding_fast.mask_labels): + self._assert_slow_fast_tensors_equivalence(mask_label_slow, mask_label_fast) + for class_label_slow, class_label_fast in zip(encoding_slow.class_labels, encoding_fast.class_labels): + self._assert_slow_fast_tensors_equivalence(class_label_slow.float(), class_label_fast.float())