diff --git a/docs/source/en/model_doc/beit.md b/docs/source/en/model_doc/beit.md index 24dfabf682..e40fbdc9c8 100644 --- a/docs/source/en/model_doc/beit.md +++ b/docs/source/en/model_doc/beit.md @@ -150,6 +150,11 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] BeitImageProcessor - preprocess - post_process_semantic_segmentation +## BeitImageProcessorFast + +[[autodoc]] BeitImageProcessorFast + - preprocess + - post_process_semantic_segmentation diff --git a/docs/source/ja/model_doc/beit.md b/docs/source/ja/model_doc/beit.md index 45eb1efa5d..e0b94693a3 100644 --- a/docs/source/ja/model_doc/beit.md +++ b/docs/source/ja/model_doc/beit.md @@ -105,6 +105,12 @@ BEiT の使用を開始するのに役立つ公式 Hugging Face およびコミ [[autodoc]] BeitImageProcessor - preprocess + - post_process_semantic_segmentation + +## BeitImageProcessorFast + +[[autodoc]] BeitImageProcessorFast + - preprocess - post_process_semantic_segmentation ## BeitModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b4d82dd874..8a57682de7 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -57,8 +57,8 @@ else: IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( [ ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), - ("aria", ("AriaImageProcessor",)), - ("beit", ("BeitImageProcessor",)), + ("aria", ("AriaImageProcessor",)), + ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")), ("bit", ("BitImageProcessor", "BitImageProcessorFast")), ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")), @@ -71,7 +71,7 @@ else: ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), - ("data2vec-vision", ("BeitImageProcessor",)), +
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")), ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), ("depth_anything", ("DPTImageProcessor",)), diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py index 44f838a9a5..3f412a3500 100644 --- a/src/transformers/models/beit/__init__.py +++ b/src/transformers/models/beit/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_beit import * from .feature_extraction_beit import * from .image_processing_beit import * + from .image_processing_beit_fast import * from .modeling_beit import * from .modeling_flax_beit import * else: diff --git a/src/transformers/models/beit/image_processing_beit_fast.py b/src/transformers/models/beit/image_processing_beit_fast.py new file mode 100644 index 0000000000..50dd6fe7f5 --- /dev/null +++ b/src/transformers/models/beit/image_processing_beit_fast.py @@ -0,0 +1,284 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Fast Image processor class for Beit."""

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torchvision.transforms import functional as F

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    is_torch_tensor,
    make_list_of_images,
    pil_torch_interpolation_mapping,
    validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import TensorType, add_start_docstrings
from ...utils.deprecation import deprecate_kwarg


class BeitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    # Beit-specific addition to the shared fast-processor kwargs: toggles the
    # label shift performed by `BeitImageProcessorFast.reduce_label`.
    do_reduce_labels: Optional[bool]


@add_start_docstrings(
    "Constructs a fast Beit image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
            is used for background, and background itself is not included in all classes of a dataset (e.g.
            ADE20k). The background label will be replaced by 255.
    """,
)
class BeitImageProcessorFast(BaseImageProcessorFast):
    # NOTE: no class docstring on purpose -- `add_start_docstrings` builds it.
    resample = PILImageResampling.BICUBIC
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 224, "width": 224}
    default_to_square = True
    crop_size = {"height": 224, "width": 224}
    do_resize = True
    do_center_crop = False
    do_rescale = True
    do_normalize = True
    do_reduce_labels = False
    valid_kwargs = BeitFastImageProcessorKwargs

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to keep supporting the deprecated `reduce_labels` key
        found in old configs. The key is renamed to `do_reduce_labels` on a copy, so the caller's dict is untouched.
        """
        image_processor_dict = image_processor_dict.copy()
        if "reduce_labels" in image_processor_dict:
            image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
        return super().from_dict(image_processor_dict, **kwargs)

    def reduce_label(self, labels: list["torch.Tensor"]) -> list["torch.Tensor"]:
        """
        Shift every label down by one and map the background label (0) to 255.

        Args:
            labels (`list[torch.Tensor]`):
                Segmentation maps. The list is updated in place and also returned.

        Returns:
            `list[torch.Tensor]`: the same list, holding the reduced maps.
        """
        for idx, label in enumerate(labels):
            # 0 -> 255 first, so that the subtraction below cannot underflow on unsigned dtypes.
            label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
            label = label - 1
            # The former background is now 254; put it back on the 255 "ignore" value.
            label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
            labels[idx] = label

        # Return the whole list (returning only the last tensor breaks the grouping done in `_preprocess`).
        return labels

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_reduce_labels: bool,
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> Union["torch.Tensor", list["torch.Tensor"]]:
        """
        Run label reduction (optional), resize, center crop and fused rescale/normalize on a list of tensors.

        Returns:
            A single stacked tensor when `return_tensors` is truthy, otherwise a list of tensors in input order.
        """
        if do_reduce_labels:
            images = self.reduce_label(images)

        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
        return processed_images

    def _preprocess_segmentation_maps(
        self,
        segmentation_maps,
        **kwargs,
    ):
        """
        Preprocesses a list of segmentation maps.

        The maps go through the same `_preprocess` pipeline as the images, but without rescaling or normalization
        and with nearest-neighbour resizing, so that no new label value is created by interpolation.

        Returns:
            An `int64` tensor without the channel axis when `_preprocess` returns a stacked tensor, otherwise a
            list of `int64` tensors without the channel axis.
        """
        processed_segmentation_maps = []
        for segmentation_map in segmentation_maps:
            segmentation_map = self._process_image(
                segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
            )

            # Add a leading channel axis so that 2D maps go through the image pipeline.
            if segmentation_map.ndim == 2:
                segmentation_map = segmentation_map[None, ...]

            processed_segmentation_maps.append(segmentation_map)

        kwargs["do_normalize"] = False
        kwargs["do_rescale"] = False
        kwargs["input_data_format"] = ChannelDimension.FIRST
        # Labels are categorical: anything other than nearest-neighbour would blend class ids.
        kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
        processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)

        if isinstance(processed_segmentation_maps, torch.Tensor):
            # Stacked output: drop the channel axis (dim 1).
            return processed_segmentation_maps.squeeze(1).to(torch.int64)

        # `return_tensors` was not set: `_preprocess` returned a list, drop the channel axis (dim 0) per map.
        return [
            segmentation_map.squeeze(0).to(torch.int64) for segmentation_map in processed_segmentation_maps
        ]

    def __call__(self, images, segmentation_maps=None, **kwargs):
        # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
        # be passed in as positional arguments.
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
    @add_start_docstrings(
        "Preprocess an image or a batch of images, optionally together with their segmentation maps.",
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
        """
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
        """,
    )
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        **kwargs: Unpack[BeitFastImageProcessorKwargs],
    ) -> BatchFeature:
        # NOTE: no docstring here on purpose -- `add_start_docstrings` builds it.
        # Output: `pixel_values`, plus `labels` when `segmentation_maps` is given.
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
        # Set default kwargs from self. This ensures that if a kwarg is not provided
        # by the user, it gets its default value from the instance, or is set to None.
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # Extract parameters that are only used for preparing the input images
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        device = kwargs.pop("device")
        # Prepare input images
        images = self._prepare_input_images(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )

        # Prepare segmentation maps
        if segmentation_maps is not None:
            segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)

        # Update kwargs that need further processing before being validated
        kwargs = self._further_process_kwargs(**kwargs)

        # Validate kwargs
        self._validate_preprocess_kwargs(**kwargs)

        # torch resize uses interpolation instead of resample
        resample = kwargs.pop("resample")
        kwargs["interpolation"] = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        # Pop kwargs that are not needed in _preprocess
        kwargs.pop("default_to_square")
        kwargs.pop("data_format")

        # Label reduction applies to segmentation maps only; it must never touch pixel values.
        image_kwargs = {**kwargs, "do_reduce_labels": False}
        images = self._preprocess(
            images=images,
            **image_kwargs,
        )
        data = {"pixel_values": images}

        if segmentation_maps is not None:
            segmentation_maps = self._preprocess_segmentation_maps(
                segmentation_maps=segmentation_maps,
                **kwargs,
            )
            data["labels"] = segmentation_maps

        return BatchFeature(data=data)

    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[List[Tuple]] = None):
        """
        Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`BeitForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.

        Raises:
            ValueError: if `target_sizes` is given and its length differs from the batch dimension of the logits.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Without target sizes the maps are simply the per-pixel argmax, split along the batch axis.
        if target_sizes is None:
            semantic_segmentation = logits.argmax(dim=1)
            return [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        if len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

        if is_torch_tensor(target_sizes):
            # Plain Python ints work for tensors on any device and are what `interpolate` expects for `size`.
            target_sizes = target_sizes.tolist()

        # Resize logits and compute semantic segmentation maps
        semantic_segmentation = []
        for idx in range(len(logits)):
            resized_logits = torch.nn.functional.interpolate(
                logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
            )
            semantic_map = resized_logits[0].argmax(dim=0)
            semantic_segmentation.append(semantic_map)

        return semantic_segmentation


__all__ = ["BeitImageProcessorFast"]
import BeitImageProcessor + if is_torchvision_available(): + from transformers import BeitImageProcessorFast + class BeitImageProcessingTester: def __init__( @@ -118,6 +121,7 @@ def prepare_semantic_batch_inputs(): @require_vision class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = BeitImageProcessor if is_vision_available() else None + fast_image_processing_class = BeitImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -128,159 +132,196 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_reduce_labels")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_reduce_labels")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = 
self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 20, "width": 20}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - self.assertEqual(image_processor.do_reduce_labels, False) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 20, "width": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + self.assertEqual(image_processor.do_reduce_labels, False) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True - ) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) - self.assertEqual(image_processor.do_reduce_labels, True) + image_processor = image_processing_class.from_dict( + self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True + ) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + self.assertEqual(image_processor.do_reduce_labels, True) def test_call_segmentation_maps(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) - maps = [] - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) - maps.append(torch.zeros(image.shape[-2:]).long()) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = 
self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + maps = [] + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + maps.append(torch.zeros(image.shape[-2:]).long()) - # Test not batched input - encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 1, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 1, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + # Test not batched input + encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test batched - encoding = image_processing(image_inputs, maps, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - self.image_processor_tester.batch_size, - 
self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + # Test batched + encoding = image_processing(image_inputs, maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test not batched input (PIL images) - image, segmentation_map = prepare_semantic_single_inputs() + # Test not batched input (PIL images) + image, segmentation_map = prepare_semantic_single_inputs() - encoding = image_processing(image, segmentation_map, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 1, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 1, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + encoding = image_processing(image, segmentation_map, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + 
self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test batched input (PIL images) - images, segmentation_maps = prepare_semantic_batch_inputs() + # Test batched input (PIL images) + images, segmentation_maps = prepare_semantic_batch_inputs() - encoding = image_processing(images, segmentation_maps, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 2, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 2, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + encoding = image_processing(images, segmentation_maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 2, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 2, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) def test_reduce_labels(self): - # Initialize 
image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) - # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 - image, map = prepare_semantic_single_inputs() - encoding = image_processing(image, map, return_tensors="pt") - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 150) + # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 + image, map = prepare_semantic_single_inputs() + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 150) - image_processing.do_reduce_labels = True - encoding = image_processing(image, map, return_tensors="pt") - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + image_processing.do_reduce_labels = True + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - def test_removed_deprecated_kwargs(self): - image_processor_dict = dict(self.image_processor_dict) - image_processor_dict.pop("do_reduce_labels", None) - image_processor_dict["reduce_labels"] = True + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") - # test we are able to create the image processor with the deprecated kwargs - image_processor = self.image_processing_class(**image_processor_dict) - self.assertEqual(image_processor.do_reduce_labels, True) + if self.image_processing_class is None or 
self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") - # test we still support reduce_labels with config - image_processor = self.image_processing_class.from_dict(image_processor_dict) - self.assertEqual(image_processor.do_reduce_labels, True) + dummy_image, dummy_map = prepare_semantic_single_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + + self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3 + ) + self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1)) + + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images, dummy_maps = prepare_semantic_batch_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = 
self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 + )