From 48292a9848fffc199c487c2b0b34f21c789acabb Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Wed, 12 Mar 2025 20:28:27 -0400 Subject: [PATCH] Refactor siglip2 fast image processor (#36406) * refactor siglip2 fast image processor, add unused_kwargs in base fast image processor * nits * change unused_kwargs default to None * update siglip2 fast image proc --- .../image_processing_utils_fast.py | 16 ++ .../siglip2/image_processing_siglip2_fast.py | 269 +++++------------- 2 files changed, 88 insertions(+), 197 deletions(-) diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index a87db33704..7d3065fda4 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -265,12 +265,14 @@ class BaseImageProcessorFast(BaseImageProcessor): device = None model_input_names = ["pixel_values"] valid_kwargs = DefaultFastImageProcessorKwargs + unused_kwargs = None def __init__( self, **kwargs: Unpack[DefaultFastImageProcessorKwargs], ) -> None: super().__init__(**kwargs) + kwargs = self.filter_out_unused_kwargs(kwargs) size = kwargs.pop("size", self.size) self.size = ( get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square)) @@ -438,6 +440,19 @@ class BaseImageProcessorFast(BaseImageProcessor): """ return convert_to_rgb(image) + def filter_out_unused_kwargs(self, kwargs: dict): + """ + Filter out the unused kwargs from the kwargs dictionary. + """ + if self.unused_kwargs is None: + return kwargs + + for kwarg_name in self.unused_kwargs: + if kwarg_name in kwargs: + logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.") + kwargs.pop(kwarg_name) + return kwargs + def _prepare_images_structure( self, images: ImageInput, @@ -634,6 +649,7 @@ class BaseImageProcessorFast(BaseImageProcessor): image_mean: Optional[Union[float, List[float]]], image_std: Optional[Union[float, List[float]]], return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: # Group images by size for batched resizing grouped_images, grouped_images_index = group_images_by_shape(images) diff --git a/src/transformers/models/siglip2/image_processing_siglip2_fast.py b/src/transformers/models/siglip2/image_processing_siglip2_fast.py index 3cb2015e36..26948a507f 100644 --- a/src/transformers/models/siglip2/image_processing_siglip2_fast.py +++ b/src/transformers/models/siglip2/image_processing_siglip2_fast.py @@ -14,78 +14,46 @@ # limitations under the License. """Fast Image processor class for SigLIP2.""" -import math from functools import lru_cache from typing import List, Optional, Tuple, Union import torch from ...image_processing_utils import BatchFeature -from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + SizeDict, +) from ...image_utils import ( - ChannelDimension, ImageInput, PILImageResampling, - SizeDict, - TensorType, ) +from ...processing_utils import Unpack from ...utils import ( - filter_out_non_signature_kwargs, + TensorType, + add_start_docstrings, is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, ) +from .image_processing_siglip2 import get_image_size_for_max_num_patches if is_torch_available(): import torch +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F -@lru_cache(maxsize=256) -# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches -def get_image_size_for_max_num_patches( - image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5 -) -> Tuple[int, int]: - """ - Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch. - Args: - image_height (`int`): - Original image height. - image_width (`int`): - Original image width. - patch_size (`int`): - Patch size for processing. - max_num_patches (`int`): - Maximum number of patches. - eps (`float`): - Small threshold for binary search. - - Returns: - Tuple: (target_height, target_width) - """ - - def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int: - scaled_size = size * scale - scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size - scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch - return int(scaled_size) - - # Binary search for optimal scale - scale_min, scale_max = eps / 10, 100.0 - while (scale_max - scale_min) >= eps: - scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size) - target_width = get_scaled_image_size(scale, image_width, patch_size) - num_patches = (target_height / patch_size) * (target_width / patch_size) - - if num_patches <= max_num_patches: - scale_min = scale - else: - scale_max = scale - - scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size) - target_width = get_scaled_image_size(scale, image_width, patch_size) - return target_height, target_width +logger = logging.get_logger(__name__) def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor": @@ -118,164 +86,71 @@ def pad_along_first_dim( return tensor, mask -class Siglip2ImageProcessorFast(BaseImageProcessorFast): - r""" - Constructs a fast SigLIP2 image processor. +class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + patch_size: Optional[int] + max_num_patches: Optional[int] - Args: - do_resize (`bool`, *optional*, defaults to `True`): - Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`. - Can be overridden by `do_resize` in the `preprocess` method. - resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): - Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. - do_rescale (`bool`, *optional*, defaults to `True`): - Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in - the `preprocess` method. - rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): - Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` - method. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether to normalize the image by the specified mean and standard deviation. Can be overridden by - `do_normalize` in the `preprocess` method. - image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): - Mean to use if normalizing the image. This is a float or list of floats the length of the number of - channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. - image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): - Standard deviation to use if normalizing the image. This is a float or list of floats the length of the - number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. - Can be overridden by the `image_std` parameter in the `preprocess` method. - do_convert_rgb (`bool`, *optional*, defaults to `True`): - Whether to convert the image to RGB. + +@add_start_docstrings( + r"Constructs a fast Siglip2 image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch the image will be split to. max_num_patches (`int`, *optional*, defaults to 256): The image will be resized to have at most this number of patches, and then padded in "patch" dimension to match this number exactly. - """ + """, +) +class Siglip2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = [0.5, 0.5, 0.5] + image_std = [0.5, 0.5, 0.5] + do_resize = True + do_rescale = True + do_normalize = True + patch_size = 16 + max_num_patches = 256 + valid_kwargs = Siglip2FastImageProcessorKwargs + unused_kwargs = ["size", "do_center_crop", "crop_size"] - def __init__( - self, - do_resize: bool = True, - resample: PILImageResampling = PILImageResampling.BILINEAR, - do_rescale: bool = True, - rescale_factor: float = 1 / 255, - do_normalize: bool = True, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: Optional[bool] = None, - patch_size: int = 16, - max_num_patches: int = 256, - **kwargs, - ): + def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]): super().__init__(**kwargs) - image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] - image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] + @lru_cache(maxsize=10) + def _prepare_process_arguments(self, **kwargs) -> tuple: + # Remove do_resize from kwargs to not raise an error as size is None + kwargs.pop("do_resize", None) + return super()._prepare_process_arguments(**kwargs) - self.do_resize = do_resize - self.resample = resample - self.do_rescale = do_rescale - self.rescale_factor = rescale_factor - self.do_normalize = do_normalize - self.image_mean = image_mean - self.image_std = image_std - self.do_convert_rgb = do_convert_rgb - self.patch_size = patch_size - self.max_num_patches = max_num_patches + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The size (resolution) of each patch the image will be split to. + max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): + The image will be resized to have at most this number of patches, + and then padded in "patch" dimension to match this number exactly. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) - @filter_out_non_signature_kwargs() - def preprocess( + def _preprocess( self, - images: ImageInput, - do_resize: Optional[bool] = None, - resample: Optional[PILImageResampling] = None, - do_rescale: Optional[bool] = None, - rescale_factor: Optional[float] = None, - do_normalize: Optional[bool] = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - do_convert_rgb: Optional[bool] = None, - patch_size: Optional[int] = None, - max_num_patches: Optional[int] = None, - device: Union["torch.device", str] = "cpu", + images: List["torch.Tensor"], + do_resize: bool, + patch_size: int, + max_num_patches: int, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, ) -> BatchFeature: - """ - Preprocess an image or batch of images. - - Args: - images (`ImageInput`): - Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If - passing in images with pixel values between 0 and 1, set `do_rescale=False`. - do_resize (`bool`, *optional*, defaults to `self.do_resize`): - Whether to resize the image. - size (`Dict[str, int]`, *optional*, defaults to `self.size`): - Size of the image after resizing. - resample (`int`, *optional*, defaults to `self.resample`): - Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only - has an effect if `do_resize` is set to `True`. - do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): - Whether to rescale the image. - rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): - Rescale factor to rescale the image by if `do_rescale` is set to `True`. - do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): - Whether to normalize the image. - image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): - Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. - image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): - Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to - `True`. - return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Can be one of: - - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): - Whether to convert the image to RGB. - patch_size (`int`, *optional*, defaults to `self.patch_size`): - Patch size for processing, same as the patch size used in the model. - max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): - Maximum number of patches per image, the image will be resized to have at most this number of patches. - """ - do_resize = do_resize if do_resize is not None else self.do_resize - resample = resample if resample is not None else self.resample - do_rescale = do_rescale if do_rescale is not None else self.do_rescale - rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor - do_normalize = do_normalize if do_normalize is not None else self.do_normalize - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - patch_size = patch_size if patch_size is not None else self.patch_size - max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches - - image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean - image_std = tuple(image_std) if isinstance(image_std, list) else image_std - - image_mean, image_std, interpolation = self._prepare_process_arguments( - do_normalize=do_normalize, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - image_mean=image_mean, - image_std=image_std, - resample=resample, - ) - - images = self._prepare_input_images( - images=images, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - device=device, - ) - pixel_masks = [] pixel_values = [] spatial_shapes = []