Refactor siglip2 fast image processor (#36406)

* refactor siglip2 fast image processor, add unused_kwargs in base fast image processor

* nits

* change unused_kwargs default to None

* update siglip2 fast image proc
This commit is contained in:
Yoni Gozlan
2025-03-12 20:28:27 -04:00
committed by GitHub
parent ea219ed164
commit 48292a9848
2 changed files with 88 additions and 197 deletions

View File

@@ -265,12 +265,14 @@ class BaseImageProcessorFast(BaseImageProcessor):
device = None
model_input_names = ["pixel_values"]
valid_kwargs = DefaultFastImageProcessorKwargs
unused_kwargs = None
def __init__(
self,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> None:
super().__init__(**kwargs)
kwargs = self.filter_out_unused_kwargs(kwargs)
size = kwargs.pop("size", self.size)
self.size = (
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
@@ -438,6 +440,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
"""
return convert_to_rgb(image)
def filter_out_unused_kwargs(self, kwargs: dict):
"""
Filter out the unused kwargs from the kwargs dictionary.
"""
if self.unused_kwargs is None:
return kwargs
for kwarg_name in self.unused_kwargs:
if kwarg_name in kwargs:
logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
kwargs.pop(kwarg_name)
return kwargs
def _prepare_images_structure(
self,
images: ImageInput,
@@ -634,6 +649,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)

View File

@@ -14,78 +14,46 @@
# limitations under the License.
"""Fast Image processor class for SigLIP2."""
import math
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import torch
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
SizeDict,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
TensorType,
)
from ...processing_utils import Unpack
from ...utils import (
filter_out_non_signature_kwargs,
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
logging,
)
from .image_processing_siglip2 import get_image_size_for_max_num_patches
if is_torch_available():
import torch
@lru_cache(maxsize=256)
# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches
def get_image_size_for_max_num_patches(
image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5
) -> Tuple[int, int]:
"""
Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch.
Args:
image_height (`int`):
Original image height.
image_width (`int`):
Original image width.
patch_size (`int`):
Patch size for processing.
max_num_patches (`int`):
Maximum number of patches.
eps (`float`):
Small threshold for binary search.
Returns:
Tuple: (target_height, target_width)
"""
def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
scaled_size = size * scale
scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size
scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch
return int(scaled_size)
# Binary search for optimal scale
scale_min, scale_max = eps / 10, 100.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(scale, image_height, patch_size)
target_width = get_scaled_image_size(scale, image_width, patch_size)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches <= max_num_patches:
scale_min = scale
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
scale_max = scale
from torchvision.transforms import functional as F
scale = scale_min
target_height = get_scaled_image_size(scale, image_height, patch_size)
target_width = get_scaled_image_size(scale, image_width, patch_size)
return target_height, target_width
logger = logging.get_logger(__name__)
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
@@ -118,164 +86,71 @@ def pad_along_first_dim(
return tensor, mask
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast SigLIP2 image processor.
class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
patch_size: Optional[int]
max_num_patches: Optional[int]
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
Can be overridden by `do_resize` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
@add_start_docstrings(
r"Constructs a fast Siglip2 image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch the image will be split to.
max_num_patches (`int`, *optional*, defaults to 256):
The image will be resized to have at most this number of patches,
and then padded in "patch" dimension to match this number exactly.
"""
""",
)
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
do_resize = True
do_rescale = True
do_normalize = True
patch_size = 16
max_num_patches = 256
valid_kwargs = Siglip2FastImageProcessorKwargs
unused_kwargs = ["size", "do_center_crop", "crop_size"]
def __init__(
self,
do_resize: bool = True,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
patch_size: int = 16,
max_num_patches: int = 256,
**kwargs,
):
def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
super().__init__(**kwargs)
image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
@lru_cache(maxsize=10)
def _prepare_process_arguments(self, **kwargs) -> tuple:
# Remove do_resize from kwargs to not raise an error as size is None
kwargs.pop("do_resize", None)
return super()._prepare_process_arguments(**kwargs)
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.patch_size = patch_size
self.max_num_patches = max_num_patches
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
resample: Optional[PILImageResampling] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: Optional[bool] = None,
patch_size: Optional[int] = None,
max_num_patches: Optional[int] = None,
device: Union["torch.device", str] = "cpu",
) -> BatchFeature:
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
Patch size for processing, same as the patch size used in the model.
The size (resolution) of each patch the image will be split to.
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
Maximum number of patches per image, the image will be resized to have at most this number of patches.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
patch_size = patch_size if patch_size is not None else self.patch_size
max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
do_normalize=do_normalize,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
resample=resample,
)
images = self._prepare_input_images(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
The image will be resized to have at most this number of patches,
and then padded in "patch" dimension to match this number exactly.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
patch_size: int,
max_num_patches: int,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
pixel_masks = []
pixel_values = []
spatial_shapes = []