diff --git a/docs/source/en/model_doc/mobilevit.md b/docs/source/en/model_doc/mobilevit.md index 6fb69649ee..0ce9f8d21f 100644 --- a/docs/source/en/model_doc/mobilevit.md +++ b/docs/source/en/model_doc/mobilevit.md @@ -95,6 +95,12 @@ If you're interested in submitting a resource to be included here, please feel f - preprocess - post_process_semantic_segmentation +## MobileViTImageProcessorFast + +[[autodoc]] MobileViTImageProcessorFast + - preprocess + - post_process_semantic_segmentation + diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4586627b91..b8ce8c7280 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -123,8 +123,8 @@ else: ("mllama", ("MllamaImageProcessor",)), ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")), ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")), - ("mobilevit", ("MobileViTImageProcessor",)), - ("mobilevitv2", ("MobileViTImageProcessor",)), + ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")), + ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")), ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")), ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")), ("oneformer", ("OneFormerImageProcessor",)), diff --git a/src/transformers/models/mobilevit/__init__.py b/src/transformers/models/mobilevit/__init__.py index 63f4f9c472..6750449a3e 100644 --- a/src/transformers/models/mobilevit/__init__.py +++ b/src/transformers/models/mobilevit/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_mobilevit import * from .feature_extraction_mobilevit import * from .image_processing_mobilevit import * + from .image_processing_mobilevit_fast import * from .modeling_mobilevit import * from .modeling_tf_mobilevit import * else: diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py new file mode 100644 index 0000000000..251666c801 --- /dev/null +++ b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for MobileViT.""" + +from typing import Optional + +import torch + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + ChannelDimension, + PILImageResampling, + is_torch_tensor, + make_list_of_images, + pil_torch_interpolation_mapping, + validate_kwargs, +) +from ...processing_utils import Unpack +from ...utils import auto_docstring + + +class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): + Whether to flip the color channels from RGB to BGR or vice versa. + """ + + do_flip_channel_order: Optional[bool] + + +@auto_docstring +class MobileViTImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 256, "width": 256} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = None + do_convert_rgb = None + do_flip_channel_order = True + valid_kwargs = MobileVitFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _preprocess( + self, + images, + do_resize: bool, + size: Optional[dict], + interpolation: Optional[str], + do_rescale: bool, + rescale_factor: Optional[float], + do_center_crop: bool, + crop_size: Optional[dict], + do_flip_channel_order: bool, + disable_grouping: bool, + return_tensors: Optional[str], + **kwargs, + ): + processed_images = [] + + # Group images by shape for more efficient batch processing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + + # Process each group of images with the same shape + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + + # Reorder images to original sequence + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group again after resizing (in case resize produced different sizes) + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(image=stacked_images, size=crop_size) + if do_rescale: + stacked_images = self.rescale(image=stacked_images, scale=rescale_factor) + if do_flip_channel_order: + # For batched images, we need to handle them all at once + if stacked_images.ndim > 3 and stacked_images.shape[1] >= 3: + # Flip RGB → BGR for batched images + flipped = stacked_images.clone() + flipped[:, 0:3] = stacked_images[:, [2, 1, 0], ...] + stacked_images = flipped + + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + + # Stack all processed images if return_tensors is specified + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return processed_images + + def _preprocess_segmentation_maps( + self, + segmentation_maps, + **kwargs, + ): + """Preprocesses segmentation maps.""" + processed_segmentation_maps = [] + for segmentation_map in segmentation_maps: + segmentation_map = self._process_image( + segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST + ) + + if segmentation_map.ndim == 2: + segmentation_map = segmentation_map[None, ...] + + processed_segmentation_maps.append(segmentation_map) + + kwargs["do_rescale"] = False + kwargs["do_flip_channel_order"] = False + kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] + processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) + + processed_segmentation_maps = processed_segmentation_maps.squeeze(1) + + processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) + return processed_segmentation_maps + + @auto_docstring + def preprocess( + self, + images, + segmentation_maps=None, + **kwargs: Unpack[MobileVitFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + # Prepare input images + images = self._prepare_input_images( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + + # Prepare segmentation maps + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) + + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + # Pop kwargs that are not needed in _preprocess + kwargs.pop("default_to_square") + kwargs.pop("data_format") + + images = self._preprocess( + images=images, + **kwargs, + ) + + if segmentation_maps is not None: + segmentation_maps = self._preprocess_segmentation_maps( + segmentation_maps=segmentation_maps, + **kwargs, + ) + return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps}) + + return BatchFeature(data={"pixel_values": images}) + + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + +__all__ = ["MobileViTImageProcessorFast"] diff --git a/tests/models/mobilevit/test_image_processing_mobilevit.py b/tests/models/mobilevit/test_image_processing_mobilevit.py index 7df498176d..df5caa6b7f 100644 --- a/tests/models/mobilevit/test_image_processing_mobilevit.py +++ b/tests/models/mobilevit/test_image_processing_mobilevit.py @@ -15,10 +15,11 @@ import unittest +import requests from datasets import load_dataset from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -27,8 +28,13 @@ if is_torch_available(): import torch if is_vision_available(): + from PIL import Image + from transformers import MobileViTImageProcessor + if is_torchvision_available(): + from transformers import MobileViTImageProcessorFast + class MobileViTImageProcessingTester: def __init__( @@ -98,6 +104,7 @@ def prepare_semantic_batch_inputs(): @require_vision class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = MobileViTImageProcessor if is_vision_available() else None + fast_image_processing_class = MobileViTImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -108,124 +115,155 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_flip_channel_order")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_flip_channel_order")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 20}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 20}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) def test_call_segmentation_maps(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) - maps = [] - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) - maps.append(torch.zeros(image.shape[-2:]).long()) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + maps = [] + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + maps.append(torch.zeros(image.shape[-2:]).long()) - # Test not batched input - encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 1, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 1, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + # Test not batched input + encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test batched - encoding = image_processing(image_inputs, maps, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + # Test batched + encoding = image_processing(image_inputs, maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test not batched input (PIL images) + # Test not batched input (PIL images) + image, segmentation_map = prepare_semantic_single_inputs() + + encoding = image_processing(image, segmentation_map, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched input (PIL images) + images, segmentation_maps = prepare_semantic_batch_inputs() + + encoding = image_processing(images, segmentation_maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 2, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 2, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + @require_vision + @require_torch + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + # Test with single image + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + + # Test with single image and segmentation map image, segmentation_map = prepare_semantic_single_inputs() - encoding = image_processing(image, segmentation_map, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 1, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 1, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) - - # Test batched input (PIL images) - images, segmentation_maps = prepare_semantic_batch_inputs() - - encoding = image_processing(images, segmentation_maps, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 2, - self.image_processor_tester.num_channels, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 2, - self.image_processor_tester.crop_size["height"], - self.image_processor_tester.crop_size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt") + encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)