From dd7dc4a4a2281c4a3eda1247fc05e34149a55786 Mon Sep 17 00:00:00 2001 From: farrosalferro <127369839+farrosalferro@users.noreply.github.com> Date: Sat, 28 Jun 2025 00:26:57 +0900 Subject: [PATCH] Add Fast Image Processor for Chameleon (#37140) * Add Fast Image Processor for Chameleon * add warning to resize and move blend_rgba to convert_to_rgb * Remove unrelated files * Update image_processing_chameleon_fast to use auto_docstring * fix equivalence test --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Co-authored-by: yonigozlan --- docs/source/en/model_doc/chameleon.md | 5 + .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/chameleon/__init__.py | 1 + .../image_processing_chameleon_fast.py | 124 ++++++++++++++ .../test_image_processing_chameleon.py | 156 ++++++++++-------- 5 files changed, 216 insertions(+), 72 deletions(-) create mode 100644 src/transformers/models/chameleon/image_processing_chameleon_fast.py diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index e7c04811de..b0265b1b72 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -191,6 +191,11 @@ model = ChameleonForConditionalGeneration.from_pretrained( [[autodoc]] ChameleonImageProcessor - preprocess +## ChameleonImageProcessorFast + +[[autodoc]] ChameleonImageProcessorFast + - preprocess + ## ChameleonVQVAE [[autodoc]] ChameleonVQVAE diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b8ce8c7280..6466645607 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -63,7 +63,7 @@ else: ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")), ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")), - ("chameleon", 
("ChameleonImageProcessor",)), + ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")), ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")), ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")), diff --git a/src/transformers/models/chameleon/__init__.py b/src/transformers/models/chameleon/__init__.py index 4332161036..6ad11a90a2 100644 --- a/src/transformers/models/chameleon/__init__.py +++ b/src/transformers/models/chameleon/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_chameleon import * from .image_processing_chameleon import * + from .image_processing_chameleon_fast import * from .modeling_chameleon import * from .processing_chameleon import * else: diff --git a/src/transformers/models/chameleon/image_processing_chameleon_fast.py b/src/transformers/models/chameleon/image_processing_chameleon_fast.py new file mode 100644 index 0000000000..dea89a0d16 --- /dev/null +++ b/src/transformers/models/chameleon/image_processing_chameleon_fast.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2025 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for Chameleon.""" + +import numpy as np + +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import ImageInput, PILImageResampling, SizeDict +from ...utils import ( + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) + + +if is_vision_available(): + import PIL +if is_torch_available(): + import torch +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + +logger = logging.get_logger(__name__) + + +@auto_docstring +class ChameleonImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.LANCZOS + image_mean = [1.0, 1.0, 1.0] + image_std = [1.0, 1.0, 1.0] + size = {"shortest_edge": 512} + default_to_square = False + crop_size = {"height": 512, "width": 512} + do_resize = True + do_center_crop = True + do_rescale = True + rescale_factor = 0.0078 + do_normalize = True + do_convert_rgb = True + + def convert_to_rgb(self, image: ImageInput) -> ImageInput: + """ + Convert image to RGB by blending the transparency layer if it's in RGBA format. + If image is not `PIL.Image`, it is simply returned without modifications. + + Args: + image (`ImageInput`): + Image to convert. + """ + + if not isinstance(image, PIL.Image.Image): + return image + elif image.mode == "RGB": + return image + + img_rgba = np.array(image.convert("RGBA")) + + # If there is no transparency layer, simply convert and return. + if not (img_rgba[:, :, 3] < 255).any(): + return image.convert("RGB") + + # There is a transparency layer, blend it with a white background. + # Calculate the alpha proportion for blending. 
+ alpha = img_rgba[:, :, 3] / 255.0 + img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3] + return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB") + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. `InterpolationMode.LANCZOS` is replaced by `InterpolationMode.BICUBIC` and a warning is logged. + + Returns: + `torch.Tensor`: The resized image. + """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if interpolation == F.InterpolationMode.LANCZOS: + logger.warning_once( + "You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. " + "BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you " + "want full consistency with the original model." 
+ ) + interpolation = F.InterpolationMode.BICUBIC + + return super().resize( + image=image, + size=size, + interpolation=interpolation, + **kwargs, + ) + + +__all__ = ["ChameleonImageProcessorFast"] diff --git a/tests/models/chameleon/test_image_processing_chameleon.py b/tests/models/chameleon/test_image_processing_chameleon.py index fcbd7b46d5..78576725f7 100644 --- a/tests/models/chameleon/test_image_processing_chameleon.py +++ b/tests/models/chameleon/test_image_processing_chameleon.py @@ -16,8 +16,9 @@ import unittest import numpy as np +from transformers.image_utils import PILImageResampling from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -30,6 +31,9 @@ if is_vision_available(): from transformers import ChameleonImageProcessor + if is_torchvision_available(): + from transformers import ChameleonImageProcessorFast + class ChameleonImageProcessingTester: def __init__( @@ -48,6 +52,7 @@ class ChameleonImageProcessingTester: image_mean=[1.0, 1.0, 1.0], image_std=[1.0, 1.0, 1.0], do_convert_rgb=True, + resample=PILImageResampling.BILINEAR, ): size = size if size is not None else {"shortest_edge": 18} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} @@ -65,6 +70,7 @@ class ChameleonImageProcessingTester: self.image_mean = image_mean self.image_std = image_std self.do_convert_rgb = do_convert_rgb + self.resample = resample def prepare_image_processor_dict(self): return { @@ -76,6 +82,7 @@ class ChameleonImageProcessingTester: "image_mean": self.image_mean, "image_std": self.image_std, "do_convert_rgb": self.do_convert_rgb, + "resample": self.resample, } # Copied from 
tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape @@ -99,6 +106,7 @@ class ChameleonImageProcessingTester: @require_vision class ChameleonImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = ChameleonImageProcessor if is_vision_available() else None + fast_image_processing_class = ChameleonImageProcessorFast if is_torchvision_available() else None # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Chameleon def setUp(self): @@ -111,94 +119,100 @@ class ChameleonImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_center_crop")) - self.assertTrue(hasattr(image_processing, "center_crop")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) def test_image_processor_from_dict_with_kwargs(self): - 
image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"shortest_edge": 18}) - self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 18}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) - self.assertEqual(image_processor.size, {"shortest_edge": 42}) - self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) def test_call_pil(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test not batched input + encoded_images = image_processing(image_inputs[0], 
return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_numpy(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), 
expected_output_image_shape) + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_call_pytorch(self): - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) - # Test not batched input - encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values - expected_output_image_shape = (1, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) def test_nested_input(self): - image_processing = 
self.image_processing_class(**self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) - # Test batched as a list of images - encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 18, 18) - self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + # Test batched as a list of images + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) - # Test batched as a nested list of images, where each sublist is one batch - image_inputs_nested = [image_inputs[:3], image_inputs[3:]] - encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values - expected_output_image_shape = (7, 3, 18, 18) - self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) - # Image processor should return same pixel values, independently of input format - self.assertTrue((encoded_images_nested == encoded_images).all()) + # Image processor should return same pixel values, independently of input format + self.assertTrue((encoded_images_nested == encoded_images).all())