From e43f168eb3951ccd39991442ca9a94dc8987c65b Mon Sep 17 00:00:00 2001 From: Parteek Date: Mon, 14 Apr 2025 20:37:36 +0530 Subject: [PATCH] Add Fast LeViT Processor (#37154) * Add Fast LeViT Processor * Update levit.md * Update src/transformers/models/levit/image_processing_levit_fast.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> * ruff check --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --- docs/source/en/model_doc/levit.md | 5 + .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/levit/__init__.py | 1 + .../levit/image_processing_levit_fast.py | 101 ++++++++++++++++++ .../levit/test_image_processing_levit.py | 34 +++--- 5 files changed, 128 insertions(+), 15 deletions(-) create mode 100644 src/transformers/models/levit/image_processing_levit_fast.py diff --git a/docs/source/en/model_doc/levit.md b/docs/source/en/model_doc/levit.md index af42c1533e..f794f7902f 100644 --- a/docs/source/en/model_doc/levit.md +++ b/docs/source/en/model_doc/levit.md @@ -94,6 +94,11 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] LevitImageProcessor - preprocess +## LevitImageProcessorFast + + [[autodoc]] LevitImageProcessorFast + - preprocess + ## LevitModel [[autodoc]] LevitModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4439d75638..6387baa20c 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -104,7 +104,7 @@ else: ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")), ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), - ("levit", ("LevitImageProcessor",)), + ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")), ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")), ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")), ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), diff --git a/src/transformers/models/levit/__init__.py b/src/transformers/models/levit/__init__.py index ab009d9312..d3ae097b66 100644 --- a/src/transformers/models/levit/__init__.py +++ b/src/transformers/models/levit/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_levit import * from .feature_extraction_levit import * from .image_processing_levit import * + from .image_processing_levit_fast import * from .modeling_levit import * else: import sys diff --git a/src/transformers/models/levit/image_processing_levit_fast.py b/src/transformers/models/levit/image_processing_levit_fast.py new file mode 100644 index 0000000000..87b0d0ba3e --- /dev/null +++ b/src/transformers/models/levit/image_processing_levit_fast.py @@ -0,0 +1,101 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for LeViT.""" + +from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast, SizeDict +from ...image_transforms import ( + ChannelDimension, + get_resize_output_image_size, +) +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling +from ...utils import add_start_docstrings, is_torch_available, is_torchvision_available, is_torchvision_v2_available + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +@add_start_docstrings( + "Constructs a fast Levit image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class LevitImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = None + + def resize( + self, + image: torch.Tensor, + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize an image. + + If size is a dict with keys "width" and "height", the image will be resized to `(size["height"], + size["width"])`. + + If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`. + The smaller edge of the image will be matched to this value i.e, if height > width, then image will be rescaled + to `(size["shortest_egde"] * height / width, size["shortest_egde"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Size of the output image after resizing. If size is a dict with keys "width" and "height", the image + will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value + `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value + i.e, if height > width, then image will be rescaled to (size * height / width, size). + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BICUBIC`): + Resampling filter to use when resiizing the image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BICUBIC + if size.shortest_edge: + shortest_edge = int((256 / 224) * size["shortest_edge"]) + new_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST + ) + elif size.height and size.width: + new_size = (size.height, size.width) + else: + raise ValueError( + f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size.keys()} {size.keys()}." + ) + return F.resize( + image, + size=new_size, + interpolation=interpolation, + **kwargs, + ) + + +__all__ = ["LevitImageProcessorFast"] diff --git a/tests/models/levit/test_image_processing_levit.py b/tests/models/levit/test_image_processing_levit.py index 6a8f76a465..beb3c77c15 100644 --- a/tests/models/levit/test_image_processing_levit.py +++ b/tests/models/levit/test_image_processing_levit.py @@ -16,7 +16,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im if is_vision_available(): from transformers import LevitImageProcessor + if is_torchvision_available(): + from transformers import LevitImageProcessorFast + class LevitImageProcessingTester: def __init__( @@ -88,6 +91,7 @@ class LevitImageProcessingTester: @require_vision class LevitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = LevitImageProcessor if is_vision_available() else None + fast_image_processing_class = LevitImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -98,19 +102,21 @@ class LevitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "size")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "size")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 18}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 18}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})