diff --git a/docs/source/en/model_doc/swin2sr.md b/docs/source/en/model_doc/swin2sr.md index 136f1a1c1e..3ea713fdc7 100644 --- a/docs/source/en/model_doc/swin2sr.md +++ b/docs/source/en/model_doc/swin2sr.md @@ -50,6 +50,11 @@ A demo Space for image super-resolution with SwinSR can be found [here](https:// [[autodoc]] Swin2SRImageProcessor - preprocess +## Swin2SRImageProcessorFast + +[[autodoc]] Swin2SRImageProcessorFast + - preprocess + ## Swin2SRConfig [[autodoc]] Swin2SRConfig diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 8a57682de7..41d446cf4e 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -150,7 +150,7 @@ else: ("superglue", ("SuperGlueImageProcessor",)), ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), - ("swin2sr", ("Swin2SRImageProcessor",)), + ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")), ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")), ("table-transformer", ("DetrImageProcessor",)), ("timesformer", ("VideoMAEImageProcessor",)), diff --git a/src/transformers/models/swin2sr/__init__.py b/src/transformers/models/swin2sr/__init__.py index cc7ea677b5..082570230c 100644 --- a/src/transformers/models/swin2sr/__init__.py +++ b/src/transformers/models/swin2sr/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_swin2sr import * from .image_processing_swin2sr import * + from .image_processing_swin2sr_fast import * from .modeling_swin2sr import * else: import sys diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py b/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py new file mode 100644 index 0000000000..9dd056af1e --- /dev/null +++ 
b/src/transformers/models/swin2sr/image_processing_swin2sr_fast.py
@@ -0,0 +1,145 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Swin2SR."""
+
+from typing import List, Optional, Union
+
+from ...image_processing_utils import (
+    BatchFeature,
+    ChannelDimension,
+    get_image_size,
+)
+from ...image_processing_utils_fast import (
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_utils import ImageInput
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    add_start_docstrings,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+)
+
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
+
+
+class Swin2SRFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    # Swin2SR-specific keyword arguments accepted in addition to the default fast-processor ones.
+    do_pad: Optional[bool]
+    pad_size: Optional[int]
+
+
+@add_start_docstrings(
+    "Constructs a fast Swin2SR image processor.",
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    """
+        do_pad (`bool`, *optional*, defaults to `True`):
+            Whether to pad the image to make the height and width divisible by `pad_size`.
+        pad_size (`int`, *optional*, defaults to `8`):
+            The multiple to pad the height and width to (the size of the sliding window for the local attention).
+    """,
+)
+class Swin2SRImageProcessorFast(BaseImageProcessorFast):
+    do_rescale = True
+    rescale_factor = 1 / 255
+    do_pad = True
+    pad_size = 8
+    valid_kwargs = Swin2SRFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+
+    @add_start_docstrings(
+        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+        """
+            do_pad (`bool`, *optional*, defaults to `True`):
+                Whether to pad the image to make the height and width divisible by `pad_size`.
+            pad_size (`int`, *optional*, defaults to `8`):
+                The multiple to pad the height and width to (the size of the sliding window for the local attention).
+        """,
+    )
+    def preprocess(self, images: ImageInput, **kwargs: Unpack[Swin2SRFastImageProcessorKwargs]) -> BatchFeature:
+        return super().preprocess(images, **kwargs)
+
+    def pad(self, images: "torch.Tensor", size: int) -> "torch.Tensor":
+        """
+        Pad the right and bottom edges of the images so that the height and width are divisible by `size`.
+
+        Padding is always applied: a dimension that is already a multiple of `size` is still extended by a
+        full `size`.
+
+        Args:
+            images (`torch.Tensor`):
+                Images to pad.
+            size (`int`):
+                The size to make the height and width divisible by.
+
+        Returns:
+            `torch.Tensor`: The padded images.
+        """
+        height, width = get_image_size(images, ChannelDimension.FIRST)
+        pad_height = (height // size + 1) * size - height
+        pad_width = (width // size + 1) * size - width
+
+        # torchvision expects (left, top, right, bottom), so only the right and bottom edges grow.
+        return F.pad(
+            images,
+            (0, 0, pad_width, pad_height),
+            padding_mode="symmetric",
+        )
+
+    def _preprocess(
+        self,
+        images: List["torch.Tensor"],
+        do_rescale: bool,
+        rescale_factor: float,
+        do_pad: bool,
+        pad_size: int,
+        return_tensors: Optional[Union[str, TensorType]],
+        interpolation: Optional["F.InterpolationMode"],
+        **kwargs,
+    ) -> BatchFeature:
+        # `interpolation` is accepted but not used: this processor only rescales and pads.
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        processed_image_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_rescale:
+                stacked_images = self.rescale(stacked_images, scale=rescale_factor)
+            if do_pad:
+                stacked_images = self.pad(stacked_images, size=pad_size)
+            processed_image_grouped[shape] = stacked_images
+        processed_images = reorder_images(processed_image_grouped, grouped_images_index)
+        # Only stack into a single batch tensor when a tensor type was requested.
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["Swin2SRImageProcessorFast"]
diff --git a/tests/models/swin2sr/test_image_processing_swin2sr.py b/tests/models/swin2sr/test_image_processing_swin2sr.py
index ec69d7f35b..f4e70d1b0e 100644
--- a/tests/models/swin2sr/test_image_processing_swin2sr.py
+++ b/tests/models/swin2sr/test_image_processing_swin2sr.py
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -30,6 +30,9 @@ if is_vision_available():
     from PIL import Image
 
     from transformers import Swin2SRImageProcessor
+
+    if is_torchvision_available():
+        from transformers import Swin2SRImageProcessorFast
     from transformers.image_transforms import get_image_size
 
 
@@ -97,6 +100,7 @@ class Swin2SRImageProcessingTester:
 @require_vision
 class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = Swin2SRImageProcessor if is_vision_available() else None
+    fast_image_processing_class = Swin2SRImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -107,11 +111,12 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processor = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processor, "do_rescale"))
-        self.assertTrue(hasattr(image_processor, "rescale_factor"))
-        self.assertTrue(hasattr(image_processor, "do_pad"))
-        self.assertTrue(hasattr(image_processor, "pad_size"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "pad_size"))
 
     def calculate_expected_size(self, image):
         old_height, old_width = get_image_size(image)
@@ -181,3 +186,22 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
         expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
         self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+    @unittest.skip(reason="No speed gain on CPU due to minimal processing.")
+    def test_fast_is_faster_than_slow(self):
+        pass
+
+    def test_slow_fast_equivalence_batched(self):
+        # The fast class is `None` when torchvision is not installed; skip instead of erroring out.
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Slow or fast image processor is not available")
+
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoded_slow = image_processor_slow(image_inputs, return_tensors="pt").pixel_values
+        encoded_fast = image_processor_fast(image_inputs, return_tensors="pt").pixel_values
+
+        self.assertTrue(torch.allclose(encoded_slow, encoded_fast, atol=1e-1))