From de0dd3139df7fe6f521b54958829a59f89acb2c8 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:27:16 -0400 Subject: [PATCH] Add fast image processor SAM (#39385) * add fast image processor sam * nits --- docs/source/en/model_doc/sam.md | 11 +- .../models/auto/image_processing_auto.py | 4 +- .../models/auto/tokenization_auto.py | 9 + src/transformers/models/sam/__init__.py | 1 + .../models/sam/image_processing_sam.py | 5 + .../models/sam/image_processing_sam_fast.py | 865 ++++++++++++++++++ tests/models/sam/test_image_processing_sam.py | 301 ++++++ 7 files changed, 1191 insertions(+), 5 deletions(-) create mode 100644 src/transformers/models/sam/image_processing_sam_fast.py create mode 100644 tests/models/sam/test_image_processing_sam.py diff --git a/docs/source/en/model_doc/sam.md b/docs/source/en/model_doc/sam.md index cf5273e089..ac73c107b8 100644 --- a/docs/source/en/model_doc/sam.md +++ b/docs/source/en/model_doc/sam.md @@ -25,7 +25,7 @@ rendered properly in your Markdown viewer. SAM (Segment Anything Model) was proposed in [Segment Anything](https://huggingface.co/papers/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. -The model can be used to predict segmentation masks of any object of interest given an input image. +The model can be used to predict segmentation masks of any object of interest given an input image. ![example image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-output.png) @@ -37,9 +37,9 @@ Tips: - The model predicts binary masks that states the presence or not of the object of interest given an image. - The model predicts much better results if input 2D points and/or input bounding boxes are provided -- You can prompt multiple points for the same image, and predict a single mask. +- You can prompt multiple points for the same image, and predict a single mask. - Fine-tuning the model is not supported yet -- According to the paper, textual input should be also supported. However, at this time of writing this seems not to be supported according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844). +- According to the paper, textual input should be also supported. However, at this time of writing this seems not to be supported according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844). This model was contributed by [ybelkada](https://huggingface.co/ybelkada) and [ArthurZ](https://huggingface.co/ArthurZ). @@ -149,6 +149,11 @@ alt="drawing" width="900"/> [[autodoc]] SamImageProcessor +## SamImageProcessorFast + +[[autodoc]] SamImageProcessorFast + + ## SamVisionModel [[autodoc]] SamVisionModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index d3c4367eca..84e6a75b16 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -147,8 +147,8 @@ else: ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), - ("sam", ("SamImageProcessor",)), - ("sam_hq", ("SamImageProcessor",)), + ("sam", ("SamImageProcessor", "SamImageProcessorFast")), + ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9dd78dbeae..b28e1cfe4a 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -656,6 +656,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]]( ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("vits", ("VitsTokenizer", None)), + ( + "voxtral", + ( + "MistralCommonTokenizer" + if is_mistral_common_available() + else ("LlamaTokenizer" if is_sentencepiece_available() else None), + "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None, + ), + ), ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)), ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)), diff --git a/src/transformers/models/sam/__init__.py b/src/transformers/models/sam/__init__.py index 68da4037a3..bb8a2b98e6 100644 --- a/src/transformers/models/sam/__init__.py +++ b/src/transformers/models/sam/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_sam import * from .image_processing_sam import * + from .image_processing_sam_fast import * from .modeling_sam import * from .modeling_tf_sam import * from .processing_sam import * diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index e1150c7f0b..c431bb72ca 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -387,6 +387,11 @@ class SamImageProcessor(BaseImageProcessor): return segmentation_map, original_size + def __call__(self, images, segmentation_maps=None, **kwargs): + # Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both + # be passed in as positional arguments. + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + @filter_out_non_signature_kwargs() def preprocess( self, diff --git a/src/transformers/models/sam/image_processing_sam_fast.py b/src/transformers/models/sam/image_processing_sam_fast.py new file mode 100644 index 0000000000..b50ba955be --- /dev/null +++ b/src/transformers/models/sam/image_processing_sam_fast.py @@ -0,0 +1,865 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for SAM.""" + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Optional, Union + +import numpy as np +import torch + +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + make_list_of_images, + pil_torch_interpolation_mapping, + validate_kwargs, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + from torch.nn import functional as F_t + +if is_torchvision_available() and is_torchvision_v2_available(): + from torchvision.ops.boxes import batched_nms + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.ops.boxes import batched_nms + from torchvision.transforms import functional as F + + +class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. + mask_size (`dict[str, int]`, *optional*): + The size `{"longest_edge": int}` to resize the segmentation maps to. + mask_pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation + map size provided for preprocessing. + """ + + mask_size: Optional[dict[str, int]] + do_pad: Optional[bool] + pad_size: Optional[dict[str, int]] + mask_pad_size: Optional[dict[str, int]] + + +@auto_docstring +class SamImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"longest_edge": 1024} + mask_size = {"longest_edge": 256} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + valid_kwargs = SamFastImageProcessorKwargs + + do_pad = True + pad_size = {"height": 1024, "width": 1024} + mask_pad_size = {"height": 256, "width": 256} + + def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def pad_image(self, images: "torch.Tensor", pad_size: SizeDict): + """Pad images to the specified size.""" + output_height, output_width = pad_size.height, pad_size.width + input_height, input_width = images.shape[-2:] + pad_width = output_width - input_width + pad_height = output_height - input_height + padding = (0, 0, pad_width, pad_height) + return F.pad(images, padding) + + def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int): + """ + Compute the output size given input size and target long side length. + """ + oldh, oldw = old_shape + scale = longest_edge * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + newh = int(newh + 0.5) + neww = int(neww + 0.5) + return (newh, neww) + + def resize( + self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F.InterpolationMode"], **kwargs + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest + edge of the image will be resized to the specified size, while the other edge will be resized to + maintain the aspect ratio. + interpolation: + `F.InterpolationMode` filter to use when resizing the image e.g. `F.InterpolationMode.BICUBIC`. + + Returns: + `torch.Tensor`: The resized image. + """ + if not size.longest_edge: + raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}") + input_size = image.shape[-2:] + output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge) + return super().resize( + image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs + ) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_pad: bool, + pad_size: SizeDict, + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if do_pad: + stacked_images = self.pad_image(stacked_images, pad_size) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return processed_images + + def _preprocess_segmentation_maps( + self, + segmentation_maps, + **kwargs, + ): + """Preprocesses segmentation maps.""" + processed_segmentation_maps = [] + for segmentation_map in segmentation_maps: + segmentation_map = self._process_image( + segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST + ) + + if segmentation_map.ndim == 2: + segmentation_map = segmentation_map[None, ...] + processed_segmentation_maps.append(segmentation_map) + + kwargs["do_rescale"] = False + kwargs["do_normalize"] = False + kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] + kwargs["size"] = kwargs.pop("mask_size") + kwargs["pad_size"] = kwargs.pop("mask_pad_size") + processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) + + processed_segmentation_maps = processed_segmentation_maps.squeeze(1) # Remove channel dimension + + processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) + return processed_segmentation_maps + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + pad_size: Optional[SizeDict] = None, + mask_size: Optional[SizeDict] = None, + mask_pad_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if pad_size is not None: + pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size")) + if mask_size is not None: + mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size")) + if mask_pad_size is not None: + mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + kwargs["size"] = size + kwargs["pad_size"] = pad_size + kwargs["mask_size"] = mask_size + kwargs["mask_pad_size"] = mask_pad_size + kwargs["default_to_square"] = default_to_square + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["data_format"] = data_format + + return kwargs + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[SamFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + # Prepare input images + images = self._prepare_input_images( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + + # Prepare segmentation maps + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) + + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + # Pop kwargs that are not needed in _preprocess + kwargs.pop("default_to_square") + kwargs.pop("data_format") + + original_sizes = [image.shape[-2:] for image in images] + + images = self._preprocess( + images=images, + **kwargs, + ) + reshaped_input_sizes = [image.shape[-2:] for image in images] + + if segmentation_maps is not None: + segmentation_maps = self._preprocess_segmentation_maps( + segmentation_maps=segmentation_maps, + **kwargs, + ) + + return BatchFeature( + data={ + "pixel_values": images, + "labels": segmentation_maps, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + }, + tensor_type=kwargs["return_tensors"], + ) + + return BatchFeature( + data={ + "pixel_values": images, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + }, + tensor_type=kwargs["return_tensors"], + ) + + def generate_crop_boxes( + self, + image: "torch.Tensor", + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, + device: Optional["torch.device"] = None, + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`torch.Tensor`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sample from each crop. + crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*, defaults to None): + Device to use for the computation. If None, cpu will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + """ + image = self._process_image(image) + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + ) + if device is None: + device = torch.device("cpu") + crop_boxes = crop_boxes.to(device) + points_per_crop = points_per_crop.to(device) + # cropped_images stays as torch.Tensor + input_labels = input_labels.to(device) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`torch.Tensor`): + Input masks. + iou_scores (`torch.Tensor`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`torch.Tensor`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + original_height, original_width = original_size + iou_scores = iou_scores.flatten(0, 1) + masks = masks.flatten(0, 1) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + if masks.device != iou_scores.device: + iou_scores = iou_scores.to(masks.device) + + batch_size = masks.shape[0] + + keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box(masks) + + keep_mask = ~_is_box_near_crop_edge( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppression + masks = _mask_to_rle(masks) + + return masks, scores, converted_boxes + + def post_process_masks( + self, + masks, + original_sizes, + reshaped_input_sizes, + mask_threshold=0.0, + binarize=True, + pad_size=None, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. + """ + pad_size = self.size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + if isinstance(original_sizes, (torch.Tensor, np.ndarray)): + original_sizes = original_sizes.tolist() + if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)): + reshaped_input_sizes = reshaped_input_sizes.tolist() + + output_masks = [] + for i, original_size in enumerate(original_sizes): + if isinstance(masks[i], np.ndarray): + masks[i] = torch.from_numpy(masks[i]) + elif not isinstance(masks[i], torch.Tensor): + raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`") + interpolated_mask = F_t.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False) + interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]] + interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False) + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + + def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. + + Args: + all_masks (`torch.Tensor`): + List of all predicted segmentation masks + all_scores (`torch.Tensor`): + List of all predicted iou scores + all_boxes (`torch.Tensor`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + """ + return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh) + + +def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + ) + unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + stability_scores = intersections / unions + return stability_scores + + +def _mask_to_rle(input_mask: "torch.Tensor"): + """ + Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. + """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = input_mask.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + if len(cur_idxs) == 0: + # No changes => either all 0 or all 1 + # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. + if input_mask[i, 0] == 0: + out.append({"size": [height, width], "counts": [height * width]}) + else: + out.append({"size": [height, width], "counts": [0, height * width]}) + continue + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()] + out.append({"size": [height, width], "counts": counts}) + return out + + +def _batched_mask_to_box(masks: "torch.Tensor"): + """ + Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which + corresponds the following required indices: + - LEFT: left hand side of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + + Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape + is channel_1 x channel_2 x ... x 4. + + Args: + - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`) + """ + # torch.max below raises an error on empty inputs, just skip in this case + + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to Cxheightxwidth + shape = masks.shape + height, width = shape[-2:] + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + height * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + width * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + out = out.reshape(*shape[:-2], 4) + return out + + +def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + + left, top, _, _ = crop_box + offset = torch.tensor([[left, top, left, top]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + boxes = (boxes + offset).float() + + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return torch.nn.functional.pad(masks, pad, value=0) + + +def _generate_crop_boxes( + image, + target_size: int, # Is it tuple here? + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, +) -> tuple[list[list[int]], list[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. + overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*): + Number of points to sample per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + original_size = image.shape[-2:] + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size + ) + crop_boxes = torch.tensor(crop_boxes) + crop_boxes = crop_boxes.float() + points_per_crop = torch.stack(point_grid_per_crop) + points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3) + cropped_images = torch.stack(cropped_images) + + input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format + consists of the following required indices: + - X: X coordinate of the top left of the bounding box + - Y: Y coordinate of the top left of the bounding box + - W: width of the bounding box + - H: height of the bounding box + """ + crop_boxes, layer_idxs = [], [] + im_height, im_width = original_size + short_side = min(im_height, im_width) + + # Original image + crop_boxes.append([0, 0, im_width, im_height]) + layer_idxs.append(0) + for i_layer in range(crop_n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side)) + crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side)) + + crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)] + + for left, top in product(crop_box_x0, crop_box_y0): + box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def _build_point_grid(n_per_side: int) -> torch.Tensor: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = torch.linspace(offset, 1 - offset, n_per_side) + points_x = torch.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = torch.tile(points_one_side[:, None], (1, n_per_side)) + points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2) + return points + + +def _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None +): + """ + Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are + also passed. + """ + cropped_images = [] + total_points_per_crop = [] + for i, crop_box in enumerate(crop_boxes): + left, top, right, bottom = crop_box + cropped_im = image[:, top:bottom, left:right] + + cropped_images.append(cropped_im) + + cropped_im_size = cropped_im.shape[-2:] + points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0) + + points = points_grid[layer_idxs[i]] * points_scale + normalized_points = _normalize_coordinates(target_size, points, original_size) + total_points_per_crop.append(normalized_points) + + return cropped_images, total_points_per_crop + + +def _normalize_coordinates( + target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False +) -> torch.Tensor: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. + """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor: + """Compute a binary mask from an uncompressed RLE.""" + height, width = rle["size"] + mask = torch.empty(height * width, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity = not parity + mask = mask.reshape(width, height) + return mask.transpose(0, 1) # Reshape to original shape + + +def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maximum Suppression) on the outputs. + + Args: + rle_masks (`torch.Tensor`): + binary masks in the RLE format + iou_scores (`torch.Tensor` of shape (nb_masks, 1)): + iou_scores predicted by the model + mask_boxes (`torch.Tensor`): + The bounding boxes corresponding to segmentation masks + amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): + NMS threshold. + """ + keep_by_nms = batched_nms( + boxes=mask_boxes.float(), + scores=iou_scores, + idxs=torch.zeros(mask_boxes.shape[0]), + iou_threshold=amg_crops_nms_thresh, + ) + + iou_scores = iou_scores[keep_by_nms] + rle_masks = [rle_masks[i] for i in keep_by_nms] + mask_boxes = mask_boxes[keep_by_nms] + masks = [_rle_to_mask(rle) for rle in rle_masks] + + return masks, iou_scores, rle_masks, mask_boxes + + +__all__ = ["SamImageProcessorFast"] diff --git a/tests/models/sam/test_image_processing_sam.py b/tests/models/sam/test_image_processing_sam.py new file mode 100644 index 0000000000..c6aef45b15 --- /dev/null +++ b/tests/models/sam/test_image_processing_sam.py @@ -0,0 +1,301 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from datasets import load_dataset + +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torchvision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from transformers import SamImageProcessor + + if is_torchvision_available(): + from transformers import SamImageProcessorFast + + +class SamImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_pad=True, + pad_size=None, + mask_size=None, + mask_pad_size=None, + do_resize=True, + size=None, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + size = size if size is not None else {"longest_edge": 20} + pad_size = pad_size if pad_size is not None else {"height": 20, "width": 20} + mask_size = mask_size if mask_size is not None else {"longest_edge": 12} + mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 12, "width": 12} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_pad = do_pad + self.pad_size = pad_size + self.mask_size = mask_size + self.mask_pad_size = mask_pad_size + self.do_resize = do_resize + self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_image_processor_dict(self): + return { + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_normalize": self.do_normalize, + "do_resize": self.do_resize, + "size": self.size, + "do_pad": self.do_pad, + "pad_size": self.pad_size, + "mask_size": self.mask_size, + "mask_pad_size": self.mask_pad_size, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.pad_size["height"], self.pad_size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_single_inputs +def prepare_semantic_single_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + example = ds[0] + return example["image"], example["map"] + + +# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_batch_inputs +def prepare_semantic_batch_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + return list(ds["image"][:2]), list(ds["map"][:2]) + + +@require_torch +@require_vision +class SamImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = SamImageProcessor if is_vision_available() else None + fast_image_processing_class = SamImageProcessorFast if is_torchvision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = SamImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_pad")) + self.assertTrue(hasattr(image_processing, "pad_size")) + self.assertTrue(hasattr(image_processing, "mask_size")) + self.assertTrue(hasattr(image_processing, "mask_pad_size")) + + def test_image_processor_from_dict_with_kwargs(self): + for image_processing_class in self.image_processor_list: + image_processing_class = image_processing_class(**self.image_processor_dict) + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"longest_edge": 20}) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, size={"longest_edge": 42}) + self.assertEqual(image_processor.size, {"longest_edge": 42}) + + def test_call_segmentation_maps(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + maps = [] + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + maps.append(torch.zeros(image.shape[-2:]).long()) + + # Test not batched input + encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.pad_size["height"], + self.image_processor_tester.pad_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.mask_pad_size["height"], + self.image_processor_tester.mask_pad_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched + encoding = image_processor(image_inputs, maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + self.image_processor_tester.pad_size["height"], + self.image_processor_tester.pad_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.mask_pad_size["height"], + self.image_processor_tester.mask_pad_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test not batched input (PIL images) + image, segmentation_map = prepare_semantic_single_inputs() + + encoding = image_processor(image, segmentation_map, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.pad_size["height"], + self.image_processor_tester.pad_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.mask_pad_size["height"], + self.image_processor_tester.mask_pad_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched input (PIL images) + images, segmentation_maps = prepare_semantic_batch_inputs() + + encoding = image_processor(images, segmentation_maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 2, + self.image_processor_tester.num_channels, + self.image_processor_tester.pad_size["height"], + self.image_processor_tester.pad_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 2, + self.image_processor_tester.mask_pad_size["height"], + self.image_processor_tester.mask_pad_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image, dummy_map = prepare_semantic_single_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + + self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3 + ) + self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1)) + + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_images, dummy_maps = prepare_semantic_batch_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 + )