Refactor siglip2 fast image processor (#36406)
* refactor siglip2 fast image processor, add unused_kwargs in base fast image processor * nits * change unused_kwargs default to None * update siglip2 fast image proc
This commit is contained in:
@@ -265,12 +265,14 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
device = None
|
||||
model_input_names = ["pixel_values"]
|
||||
valid_kwargs = DefaultFastImageProcessorKwargs
|
||||
unused_kwargs = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
kwargs = self.filter_out_unused_kwargs(kwargs)
|
||||
size = kwargs.pop("size", self.size)
|
||||
self.size = (
|
||||
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
|
||||
@@ -438,6 +440,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
"""
|
||||
return convert_to_rgb(image)
|
||||
|
||||
def filter_out_unused_kwargs(self, kwargs: dict):
|
||||
"""
|
||||
Filter out the unused kwargs from the kwargs dictionary.
|
||||
"""
|
||||
if self.unused_kwargs is None:
|
||||
return kwargs
|
||||
|
||||
for kwarg_name in self.unused_kwargs:
|
||||
if kwarg_name in kwargs:
|
||||
logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
|
||||
kwargs.pop(kwarg_name)
|
||||
return kwargs
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
@@ -634,6 +649,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
|
||||
@@ -14,78 +14,46 @@
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for SigLIP2."""
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import BaseImageProcessorFast
|
||||
from ...image_processing_utils_fast import (
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
SizeDict,
|
||||
)
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
TensorType,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
filter_out_non_signature_kwargs,
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
logging,
|
||||
)
|
||||
from .image_processing_siglip2 import get_image_size_for_max_num_patches
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
# Copied from transformers.models.siglip2.image_processing_siglip2.get_image_size_for_max_num_patches
|
||||
def get_image_size_for_max_num_patches(
|
||||
image_height: int, image_width: int, patch_size: int, max_num_patches: int, eps: float = 1e-5
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determine image size based on max number of patches, ensure dimensions are divisible by patch size and image is at least 1 patch.
|
||||
|
||||
Args:
|
||||
image_height (`int`):
|
||||
Original image height.
|
||||
image_width (`int`):
|
||||
Original image width.
|
||||
patch_size (`int`):
|
||||
Patch size for processing.
|
||||
max_num_patches (`int`):
|
||||
Maximum number of patches.
|
||||
eps (`float`):
|
||||
Small threshold for binary search.
|
||||
|
||||
Returns:
|
||||
Tuple: (target_height, target_width)
|
||||
"""
|
||||
|
||||
def get_scaled_image_size(scale: float, size: int, patch_size: int) -> int:
|
||||
scaled_size = size * scale
|
||||
scaled_size = math.ceil(scaled_size / patch_size) * patch_size # make divisible by patch_size
|
||||
scaled_size = max(patch_size, scaled_size) # ensure at least 1 patch
|
||||
return int(scaled_size)
|
||||
|
||||
# Binary search for optimal scale
|
||||
scale_min, scale_max = eps / 10, 100.0
|
||||
while (scale_max - scale_min) >= eps:
|
||||
scale = (scale_min + scale_max) / 2
|
||||
target_height = get_scaled_image_size(scale, image_height, patch_size)
|
||||
target_width = get_scaled_image_size(scale, image_width, patch_size)
|
||||
num_patches = (target_height / patch_size) * (target_width / patch_size)
|
||||
|
||||
if num_patches <= max_num_patches:
|
||||
scale_min = scale
|
||||
else:
|
||||
scale_max = scale
|
||||
|
||||
scale = scale_min
|
||||
target_height = get_scaled_image_size(scale, image_height, patch_size)
|
||||
target_width = get_scaled_image_size(scale, image_width, patch_size)
|
||||
return target_height, target_width
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor":
|
||||
@@ -118,164 +86,71 @@ def pad_along_first_dim(
|
||||
return tensor, mask
|
||||
|
||||
|
||||
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
|
||||
r"""
|
||||
Constructs a fast SigLIP2 image processor.
|
||||
class Siglip2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
patch_size: Optional[int]
|
||||
max_num_patches: Optional[int]
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's dimensions to fit `max_num_patches` according to given `patch_size`.
|
||||
Can be overridden by `do_resize` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
||||
`do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
|
||||
@add_start_docstrings(
|
||||
r"Constructs a fast Siglip2 image processor.",
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
|
||||
"""
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch the image will be split to.
|
||||
max_num_patches (`int`, *optional*, defaults to 256):
|
||||
The image will be resized to have at most this number of patches,
|
||||
and then padded in "patch" dimension to match this number exactly.
|
||||
"""
|
||||
""",
|
||||
)
|
||||
class Siglip2ImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
image_mean = [0.5, 0.5, 0.5]
|
||||
image_std = [0.5, 0.5, 0.5]
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
patch_size = 16
|
||||
max_num_patches = 256
|
||||
valid_kwargs = Siglip2FastImageProcessorKwargs
|
||||
unused_kwargs = ["size", "do_center_crop", "crop_size"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
patch_size: int = 16,
|
||||
max_num_patches: int = 256,
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
|
||||
image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
|
||||
@lru_cache(maxsize=10)
|
||||
def _prepare_process_arguments(self, **kwargs) -> tuple:
|
||||
# Remove do_resize from kwargs to not raise an error as size is None
|
||||
kwargs.pop("do_resize", None)
|
||||
return super()._prepare_process_arguments(**kwargs)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.patch_size = patch_size
|
||||
self.max_num_patches = max_num_patches
|
||||
@add_start_docstrings(
|
||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||
"""
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The size (resolution) of each patch the image will be split to.
|
||||
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
|
||||
The image will be resized to have at most this number of patches,
|
||||
and then padded in "patch" dimension to match this number exactly.
|
||||
""",
|
||||
)
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
def _preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
resample: Optional[PILImageResampling] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
max_num_patches: Optional[int] = None,
|
||||
device: Union["torch.device", str] = "cpu",
|
||||
images: List["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
patch_size: int,
|
||||
max_num_patches: int,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
Patch size for processing, same as the patch size used in the model.
|
||||
max_num_patches (`int`, *optional*, defaults to `self.max_num_patches`):
|
||||
Maximum number of patches per image, the image will be resized to have at most this number of patches.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
max_num_patches = max_num_patches if max_num_patches is not None else self.max_num_patches
|
||||
|
||||
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
|
||||
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
|
||||
|
||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
||||
do_normalize=do_normalize,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
images = self._prepare_input_images(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
pixel_masks = []
|
||||
pixel_values = []
|
||||
spatial_shapes = []
|
||||
|
||||
Reference in New Issue
Block a user