Refactor siglip2 fast image processor (#36406)

* refactor siglip2 fast image processor, add unused_kwargs in base fast image processor

* nits

* change unused_kwargs default to None

* update siglip2 fast image proc
This commit is contained in:
Yoni Gozlan
2025-03-12 20:28:27 -04:00
committed by GitHub
parent ea219ed164
commit 48292a9848
2 changed files with 88 additions and 197 deletions

View File

@@ -265,12 +265,14 @@ class BaseImageProcessorFast(BaseImageProcessor):
device = None device = None
model_input_names = ["pixel_values"] model_input_names = ["pixel_values"]
valid_kwargs = DefaultFastImageProcessorKwargs valid_kwargs = DefaultFastImageProcessorKwargs
unused_kwargs = None
def __init__( def __init__(
self, self,
**kwargs: Unpack[DefaultFastImageProcessorKwargs], **kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
kwargs = self.filter_out_unused_kwargs(kwargs)
size = kwargs.pop("size", self.size) size = kwargs.pop("size", self.size)
self.size = ( self.size = (
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square)) get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
@@ -438,6 +440,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
""" """
return convert_to_rgb(image) return convert_to_rgb(image)
def filter_out_unused_kwargs(self, kwargs: dict):
    """
    Strip the keyword arguments that this processor declares as unused.

    Every name listed in `self.unused_kwargs` that is present in `kwargs` is
    reported through `logger.warning_once` and then removed from `kwargs` in
    place. When `self.unused_kwargs` is `None`, nothing is filtered.

    Args:
        kwargs (`dict`):
            Keyword arguments to filter. The dictionary is modified in place.

    Returns:
        `dict`: The same `kwargs` object, without the unused entries.
    """
    ignored_names = self.unused_kwargs
    # `None` means the processor has not declared any kwargs to ignore.
    if ignored_names is None:
        return kwargs

    for kwarg_name in ignored_names:
        if kwarg_name not in kwargs:
            continue
        logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
        del kwargs[kwarg_name]

    return kwargs
def _prepare_images_structure( def _prepare_images_structure(
self, self,
images: ImageInput, images: ImageInput,
@@ -634,6 +649,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]], image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]], image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]], return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature: ) -> BatchFeature:
# Group images by size for batched resizing # Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images) grouped_images, grouped_images_index = group_images_by_shape(images)

View File

@@ -14,78 +14,46 @@
# limitations under the License. # limitations under the License.
"""Fast Image processor class for SigLIP2.""" """Fast Image processor class for SigLIP2."""
import math
from functools import lru_cache from functools import lru_cache
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from ...image_processing_utils import BatchFeature from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
SizeDict,
)
from ...image_utils import ( from ...image_utils import (
ChannelDimension,
ImageInput, ImageInput,
PILImageResampling, PILImageResampling,
SizeDict,
TensorType,
) )
from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
filter_out_non_signature_kwargs, TensorType,
add_start_docstrings,
is_torch_available, is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
logging,
) )
from .image_processing_siglip2 import get_image_size_for_max_num_patches
if is_torch_available(): if is_torch_available():
import torch import torch
if is_torchvision_available():
@lru_cache(maxsize=256) if is_torchvision_v2_available():
# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches from torchvision.transforms.v2 import functional as F
def get_image_size_for_max_num_patches(
image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5
) -> Tuple[int, int]:
"""
Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch.
Args:
image_height (`int`):
Original image height.
image_width (`int`):
Original image width.
patch_size (`int`):
Patch size for processing.
max_num_patches (`int`):
Maximum number of patches.
eps (`float`):
Small threshold for binary search.
Returns:
Tuple: (target_height, target_width)
"""
def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
scaled_size = size * scale
scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size
scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch
return int(scaled_size)
# Binary search for optimal scale
scale_min, scale_max = eps / 10, 100.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(scale, image_height, patch_size)
target_width = get_scaled_image_size(scale, image_width, patch_size)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches <= max_num_patches:
scale_min = scale
else: else:
scale_max = scale from torchvision.transforms import functional as F
scale = scale_min
target_height = get_scaled_image_size(scale, image_height, patch_size) logger = logging.get_logger(__name__)
target_width = get_scaled_image_size(scale, image_width, patch_size)
return target_height, target_width
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor": def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
@@ -118,164 +86,71 @@ def pad_along_first_dim(
return tensor, mask return tensor, mask
class Siglip2ImageProcessorFast(BaseImageProcessorFast): class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
r""" patch_size: Optional[int]
Constructs a fast SigLIP2 image processor. max_num_patches: Optional[int]
Args:
do_resize (`bool`, *optional*, defaults to `True`): @add_start_docstrings(
Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`. r"Constructs a fast Siglip2 image processor.",
Can be overridden by `do_resize` in the `preprocess` method. BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): """
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to 16): patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch the image will be split to. The size (resolution) of each patch the image will be split to.
max_num_patches (`int`, *optional*, defaults to 256): max_num_patches (`int`, *optional*, defaults to 256):
The image will be resized to have at most this number of patches, The image will be resized to have at most this number of patches,
and then padded in "patch" dimension to match this number exactly. and then padded in "patch" dimension to match this number exactly.
""" """,
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
do_resize = True
do_rescale = True
do_normalize = True
patch_size = 16
max_num_patches = 256
valid_kwargs = Siglip2FastImageProcessorKwargs
unused_kwargs = ["size", "do_center_crop", "crop_size"]
def __init__( def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
self,
do_resize: bool = True,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
patch_size: int = 16,
max_num_patches: int = 256,
**kwargs,
):
super().__init__(**kwargs) super().__init__(**kwargs)
image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] @lru_cache(maxsize=10)
image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] def _prepare_process_arguments(self, **kwargs) -> tuple:
# Remove do_resize from kwargs to not raise an error as size is None
kwargs.pop("do_resize", None)
return super()._prepare_process_arguments(**kwargs)
self.do_resize = do_resize @add_start_docstrings(
self.resample = resample BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.patch_size = patch_size
self.max_num_patches = max_num_patches
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
resample: Optional[PILImageResampling] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: Optional[bool] = None,
patch_size: Optional[int] = None,
max_num_patches: Optional[int] = None,
device: Union["torch.device", str] = "cpu",
) -> BatchFeature:
""" """
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`): patch_size (`int`, *optional*, defaults to `self.patch_size`):
Patch size for processing, same as the patch size used in the model. The size (resolution) of each patch the image will be split to.
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`): max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
Maximum number of patches per image, the image will be resized to have at most this number of patches. The image will be resized to have at most this number of patches,
""" and then padded in "patch" dimension to match this number exactly.
do_resize = do_resize if do_resize is not None else self.do_resize """,
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
patch_size = patch_size if patch_size is not None else self.patch_size
max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
do_normalize=do_normalize,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
resample=resample,
)
images = self._prepare_input_images(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
) )
def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
# Thin override: forwards `images` and all keyword arguments unchanged to the parent class `preprocess`.
# NOTE(review): presumably this override exists only so the signature advertises
# `Siglip2FastImageProcessorKwargs` and the decorator above can attach the docstring — confirm.
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
patch_size: int,
max_num_patches: int,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
pixel_masks = [] pixel_masks = []
pixel_values = [] pixel_values = []
spatial_shapes = [] spatial_shapes = []