From a7d2bbaaa8aac64f7c1ee8c1421cfe84b38359a4 Mon Sep 17 00:00:00 2001 From: Zeeshan Khan Suri <5270999+zshn25@users.noreply.github.com> Date: Wed, 16 Apr 2025 21:59:24 +0200 Subject: [PATCH] Add EfficientNet Image PreProcessor (#37055) * added efficientnet image preprocessor but tests fail * ruff checks pass * ruff formatted * properly pass rescale_offset through the functions * - corrected indentation, ordering of methods - reshape test passes when casted to float64 - equivalence test doesn't pass * all tests now pass - changes order of rescale, normalize acc to slow - rescale_offset defaults to False acc to slow - resample was causing difference in fast and slow. Changing test to bilinear resolves this difference * ruff reformat * F.InterpolationMode.NEAREST_EXACT gives TypeError: Object of type InterpolationMode is not JSON serializable * fixes offset not being applied when do_rescale and do_normalization are both true * - using nearest_exact sampling - added tests for rescale + normalize * resolving reviews --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --- docs/source/en/model_doc/efficientnet.md | 5 + src/transformers/image_utils.py | 2 +- .../models/auto/image_processing_auto.py | 4 +- .../models/efficientnet/__init__.py | 1 + .../image_processing_efficientnet_fast.py | 226 ++++++++++++++++++ .../test_image_processing_efficientnet.py | 104 ++++++-- 6 files changed, 321 insertions(+), 21 deletions(-) create mode 100644 src/transformers/models/efficientnet/image_processing_efficientnet_fast.py diff --git a/docs/source/en/model_doc/efficientnet.md b/docs/source/en/model_doc/efficientnet.md index a34378fa47..17a96aeb5a 100644 --- a/docs/source/en/model_doc/efficientnet.md +++ b/docs/source/en/model_doc/efficientnet.md @@ -43,6 +43,11 @@ The original code can be found [here](https://github.com/tensorflow/tpu/tree/mas [[autodoc]] EfficientNetImageProcessor - preprocess +## EfficientNetImageProcessorFast + +[[autodoc]] EfficientNetImageProcessorFast + - preprocess + ## EfficientNetModel [[autodoc]] EfficientNetModel diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 21dbbe374c..b9241b057f 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -65,7 +65,7 @@ if is_vision_available(): from torchvision.transforms import InterpolationMode pil_torch_interpolation_mapping = { - PILImageResampling.NEAREST: InterpolationMode.NEAREST, + PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT, PILImageResampling.BOX: InterpolationMode.BOX, PILImageResampling.BILINEAR: InterpolationMode.BILINEAR, PILImageResampling.HAMMING: InterpolationMode.HAMMING, diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index d04c6b6ee4..082a781996 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -56,7 +56,7 @@ if TYPE_CHECKING: else: IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( [ - ("align", ("EfficientNetImageProcessor",)), + ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), ("aria", ("AriaImageProcessor",)), ("beit", ("BeitImageProcessor",)), ("bit", ("BitImageProcessor", "BitImageProcessorFast")), @@ -83,7 +83,7 @@ else: ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")), ("dpt", ("DPTImageProcessor",)), ("efficientformer", ("EfficientFormerImageProcessor",)), - ("efficientnet", ("EfficientNetImageProcessor",)), + ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")), ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")), ("fuyu", ("FuyuImageProcessor",)), diff --git a/src/transformers/models/efficientnet/__init__.py b/src/transformers/models/efficientnet/__init__.py index 68a2825c70..24d58e8116 100644 --- a/src/transformers/models/efficientnet/__init__.py +++ b/src/transformers/models/efficientnet/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_efficientnet import * from .image_processing_efficientnet import * + from .image_processing_efficientnet_fast import * from .modeling_efficientnet import * else: import sys diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py b/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py new file mode 100644 index 0000000000..fb63956401 --- /dev/null +++ b/src/transformers/models/efficientnet/image_processing_efficientnet_fast.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for EfficientNet.""" + +from functools import lru_cache +from typing import Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, +) +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class EfficientNetFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + rescale_offset: bool + include_top: bool + + +@add_start_docstrings( + "Constructs a fast EfficientNet image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class EfficientNetImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.NEAREST + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 346, "width": 346} + crop_size = {"height": 289, "width": 289} + do_resize = True + do_center_crop = False + do_rescale = True + rescale_factor = 1 / 255 + rescale_offset = False + do_normalize = True + include_top = True + valid_kwargs = EfficientNetFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def rescale( + self, + image: "torch.Tensor", + scale: float, + offset: Optional[bool] = True, + **kwargs, + ) -> "torch.Tensor": + """ + Rescale an image by a scale factor. + + If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is + 1/127.5, the image is rescaled between [-1, 1]. + image = image * scale - 1 + + If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1]. + image = image * scale + + Args: + image (`torch.Tensor`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + offset (`bool`, *optional*): + Whether to scale the image in both negative and positive directions. + + Returns: + `torch.Tensor`: The rescaled image. + """ + + rescaled_image = image * scale + + if offset: + rescaled_image -= 1 + + return rescaled_image + + @lru_cache(maxsize=10) + def _fuse_mean_std_and_rescale_factor( + self, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + device: Optional["torch.device"] = None, + rescale_offset: Optional[bool] = False, + ) -> tuple: + if do_rescale and do_normalize and not rescale_offset: + # Fused rescale and normalize + image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) + do_rescale = False + return image_mean, image_std, do_rescale + + def rescale_and_normalize( + self, + images: "torch.Tensor", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, list[float]], + image_std: Union[float, list[float]], + rescale_offset: bool = False, + ) -> "torch.Tensor": + """ + Rescale and normalize images. + """ + image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor( + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + device=images.device, + rescale_offset=rescale_offset, + ) + # if/elif as we use fused rescale and normalize if both are set to True + if do_rescale: + images = self.rescale(images, rescale_factor, rescale_offset) + if do_normalize: + images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std) + + return images + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + rescale_offset: bool, + do_normalize: bool, + include_top: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std, rescale_offset + ) + if include_top: + stacked_images = self.normalize(stacked_images, 0, image_std) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`): + Whether to rescale the image between [-max_range/2, scale_range/2] instead of [0, scale_range]. + include_top (`bool`, *optional*, defaults to `self.include_top`): + Normalize the image again with the standard deviation only for image classification if set to True. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + +__all__ = ["EfficientNetImageProcessorFast"] diff --git a/tests/models/efficientnet/test_image_processing_efficientnet.py b/tests/models/efficientnet/test_image_processing_efficientnet.py index 577798e9ae..cb8fc8d922 100644 --- a/tests/models/efficientnet/test_image_processing_efficientnet.py +++ b/tests/models/efficientnet/test_image_processing_efficientnet.py @@ -17,15 +17,26 @@ import unittest import numpy as np +from transformers.image_utils import PILImageResampling from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import ( + is_torch_available, + is_torchvision_available, + is_vision_available, +) from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs +if is_torch_available(): + import torch + if is_vision_available(): from transformers import EfficientNetImageProcessor + if is_torchvision_available(): + from transformers import EfficientNetImageProcessorFast + class EfficientNetImageProcessorTester: def __init__( @@ -41,6 +52,10 @@ class EfficientNetImageProcessorTester: do_normalize=True, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], + do_rescale=True, + rescale_offset=True, + rescale_factor=1 / 127.5, + resample=PILImageResampling.BILINEAR, # NEAREST is too different between PIL and torchvision ): size = size if size is not None else {"height": 18, "width": 18} self.parent = parent @@ -54,6 +69,7 @@ class EfficientNetImageProcessorTester: self.do_normalize = do_normalize self.image_mean = image_mean self.image_std = image_std + self.resample = resample def prepare_image_processor_dict(self): return { @@ -62,6 +78,7 @@ class EfficientNetImageProcessorTester: "do_normalize": self.do_normalize, "do_resize": self.do_resize, "size": self.size, + "resample": self.resample, } def expected_output_image_shape(self, images): @@ -83,6 +100,7 @@ class EfficientNetImageProcessorTester: @require_vision class EfficientNetImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = EfficientNetImageProcessor if is_vision_available() else None + fast_image_processing_class = EfficientNetImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -93,30 +111,80 @@ class EfficientNetImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) def test_rescale(self): # EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1 image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32) - image_processor = self.image_processing_class(**self.image_processor_dict) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + if image_processing_class == EfficientNetImageProcessorFast: + image = torch.from_numpy(image) - rescaled_image = image_processor.rescale(image, scale=1 / 127.5) - expected_image = (image * (1 / 127.5)).astype(np.float32) - 1 - self.assertTrue(np.allclose(rescaled_image, expected_image)) + # Scale between [-1, 1] with rescale_factor 1/127.5 and rescale_offset=True + rescaled_image = image_processor.rescale(image, scale=1 / 127.5, offset=True) + expected_image = (image * (1 / 127.5)) - 1 + self.assertTrue(torch.allclose(rescaled_image, expected_image)) - rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False) - expected_image = (image / 255.0).astype(np.float32) - self.assertTrue(np.allclose(rescaled_image, expected_image)) + # Scale between [0, 1] with rescale_factor 1/255 and rescale_offset=True + rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False) + expected_image = image / 255.0 + self.assertTrue(torch.allclose(rescaled_image, expected_image)) + + else: + rescaled_image = image_processor.rescale(image, scale=1 / 127.5, dtype=np.float64) + expected_image = (image * (1 / 127.5)).astype(np.float64) - 1 + self.assertTrue(np.allclose(rescaled_image, expected_image)) + + rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False, dtype=np.float64) + expected_image = (image / 255.0).astype(np.float64) + self.assertTrue(np.allclose(rescaled_image, expected_image)) + + @require_vision + @require_torch + def test_rescale_normalize(self): + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + image = torch.arange(0, 256, 1, dtype=torch.uint8).reshape(1, 8, 32).repeat(3, 1, 1) + image_mean_0 = (0.0, 0.0, 0.0) + image_std_0 = (1.0, 1.0, 1.0) + image_mean_1 = (0.5, 0.5, 0.5) + image_std_1 = (0.5, 0.5, 0.5) + + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + # Rescale between [-1, 1] with rescale_factor=1/127.5 and rescale_offset=True. Then normalize + rescaled_normalized = image_processor_fast.rescale_and_normalize( + image, True, 1 / 127.5, True, image_mean_0, image_std_0, True + ) + expected_image = (image * (1 / 127.5)) - 1 + expected_image = (expected_image - torch.tensor(image_mean_0).view(3, 1, 1)) / torch.tensor(image_std_0).view( + 3, 1, 1 + ) + self.assertTrue(torch.allclose(rescaled_normalized, expected_image, rtol=1e-3)) + + # Rescale between [0, 1] with rescale_factor=1/255 and rescale_offset=False. Then normalize + rescaled_normalized = image_processor_fast.rescale_and_normalize( + image, True, 1 / 255, True, image_mean_1, image_std_1, False + ) + expected_image = image * (1 / 255.0) + expected_image = (expected_image - torch.tensor(image_mean_1).view(3, 1, 1)) / torch.tensor(image_std_1).view( + 3, 1, 1 + ) + self.assertTrue(torch.allclose(rescaled_normalized, expected_image, rtol=1e-3))