From 79254c9b61ac2e1a6e0f698c38b1654159095e2e Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Wed, 12 Mar 2025 23:18:34 -0400 Subject: [PATCH] Fix rescale normalize inconsistencies in fast image processors (#36388) * fix fused rescale normalize inconsistencies * fix siglip2 fast image processor * refactor kwargs validation and fused normalize rescale * cleanup kwargs handling in preprocess * update new procs after refactor --- .../image_processing_utils_fast.py | 212 ++++++++++++------ src/transformers/image_utils.py | 43 ---- .../image_processing_deformable_detr_fast.py | 11 +- .../models/detr/image_processing_detr_fast.py | 11 +- .../gemma3/image_processing_gemma3_fast.py | 58 +---- .../image_processing_qwen2_vl_fast.py | 27 ++- .../rt_detr/image_processing_rt_detr_fast.py | 11 +- .../models/rt_detr/modular_rt_detr.py | 11 +- .../siglip2/image_processing_siglip2_fast.py | 6 +- 9 files changed, 164 insertions(+), 226 deletions(-) diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 7d3065fda4..0a201220b6 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -40,8 +40,8 @@ from .image_utils import ( get_image_type, infer_channel_dimension_format, make_flat_list_of_images, - validate_fast_preprocess_arguments, validate_kwargs, + validate_preprocess_arguments, ) from .processing_utils import Unpack from .utils import ( @@ -72,6 +72,49 @@ if is_torchvision_available(): logger = logging.get_logger(__name__) +@lru_cache(maxsize=10) +def validate_fast_preprocess_arguments( + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisibility: Optional[int] = None, + do_center_crop: 
Optional[bool] = None, + crop_size: Optional[SizeDict] = None, + do_resize: Optional[bool] = None, + size: Optional[SizeDict] = None, + resample: Optional["PILImageResampling"] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, +): + """ + Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method. + Raises `ValueError` if arguments incompatibility is caught. + """ + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisibility, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # Extra checks for ImageProcessorFast + if return_tensors is not None and return_tensors != "pt": + raise ValueError("Only returning PyTorch tensors is currently supported.") + + if data_format != ChannelDimension.FIRST: + raise ValueError("Only channel first data format is currently supported.") + + def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor": """ Squeezes a tensor, but only if the axis specified has dim 1. 
@@ -380,6 +423,23 @@ class BaseImageProcessorFast(BaseImageProcessor): """ return F.normalize(image, mean, std) + @lru_cache(maxsize=10) + def _fuse_mean_std_and_rescale_factor( + self, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + device: Optional["torch.device"] = None, + ) -> tuple: + if do_rescale and do_normalize: + # Fused rescale and normalize + image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) + do_rescale = False + return image_mean, image_std, do_rescale + def rescale_and_normalize( self, images: "torch.Tensor", @@ -392,12 +452,19 @@ class BaseImageProcessorFast(BaseImageProcessor): """ Rescale and normalize images. """ - if do_rescale and do_normalize: + image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor( + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + device=images.device, + ) + # if/elif as we use fused rescale and normalize if both are set to True + if do_normalize: images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std) elif do_rescale: - images = images * rescale_factor - elif do_normalize: - images = self.normalize(images, image_mean, image_std) + images = self.rescale(images, rescale_factor) return images @@ -527,25 +594,60 @@ class BaseImageProcessorFast(BaseImageProcessor): return processed_images - @lru_cache(maxsize=10) - def _prepare_process_arguments( + def _further_process_kwargs( self, - do_resize: bool = None, - size: Dict[str, int] = None, - resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, - do_center_crop: bool = None, - crop_size: int = None, - do_rescale: bool = None, - 
rescale_factor: float = None, - do_normalize: bool = None, + size: Optional[SizeDict] = None, + crop_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - device: Optional["torch.device"] = None, - ) -> tuple: + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: """ - Prepare the arguments for the process method. + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if crop_size is not None: + crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + kwargs["size"] = size + kwargs["crop_size"] = crop_size + kwargs["default_to_square"] = default_to_square + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["data_format"] = data_format + + return kwargs + + def _validate_preprocess_kwargs( + self, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, tuple[float]]] = None, + image_std: Optional[Union[float, tuple[float]]] = None, + do_resize: Optional[bool] = None, + size: Optional[SizeDict] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[SizeDict] = None, + resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + 
data_format: Optional[ChannelDimension] = None, + **kwargs, + ): + """ + validate the kwargs for the preprocess method. """ validate_fast_preprocess_arguments( do_rescale=do_rescale, @@ -562,23 +664,8 @@ class BaseImageProcessorFast(BaseImageProcessor): data_format=data_format, ) - if do_rescale and do_normalize: - # Fused rescale and normalize - image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) - image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) - - interpolation = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample - ) - - return image_mean, image_std, interpolation - @add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS) - def preprocess( - self, - images: ImageInput, - **kwargs: Unpack[DefaultFastImageProcessorKwargs], - ) -> BatchFeature: + def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature: validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) # Set default kwargs from self. This ensures that if a kwarg is not provided # by the user, it gets its default value from the instance, or is set to None. 
@@ -589,51 +676,28 @@ class BaseImageProcessorFast(BaseImageProcessor): do_convert_rgb = kwargs.pop("do_convert_rgb") input_data_format = kwargs.pop("input_data_format") device = kwargs.pop("device") - + # Prepare input images images = self._prepare_input_images( images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device ) - # Pop kwargs that need further processing or won't be used in _preprocess - default_to_square = kwargs.pop("default_to_square") - size = kwargs.pop("size") - crop_size = kwargs.pop("crop_size") - image_mean = kwargs.pop("image_mean") - image_std = kwargs.pop("image_std") - data_format = kwargs.pop("data_format") + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample resample = kwargs.pop("resample") - - # Make hashable for cache - size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None - crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None - image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean - image_std = tuple(image_std) if isinstance(image_std, list) else image_std - - image_mean, image_std, interpolation = self._prepare_process_arguments( - size=size, - crop_size=crop_size, - resample=resample, - image_mean=image_mean, - image_std=image_std, - data_format=data_format if data_format is not None else ChannelDimension.FIRST, - device=images[0].device, - do_resize=kwargs.get("do_resize"), - do_center_crop=kwargs.get("do_center_crop"), - do_rescale=kwargs.get("do_rescale"), - rescale_factor=kwargs.get("rescale_factor"), - do_normalize=kwargs.get("do_normalize"), - return_tensors=kwargs.get("return_tensors"), + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if 
isinstance(resample, (PILImageResampling, int)) else resample ) - return self._preprocess( - images=images, - size=size, - crop_size=crop_size, - interpolation=interpolation, - image_mean=image_mean, - image_std=image_std, - **kwargs, - ) + # Pop kwargs that are not needed in _preprocess + kwargs.pop("default_to_square") + kwargs.pop("data_format") + + return self._preprocess(images=images, **kwargs) def _preprocess( self, diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index fec1e9dbc0..bde61e3803 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -26,7 +26,6 @@ from packaging import version from .utils import ( ExplicitEnum, - TensorType, is_av_available, is_cv2_available, is_decord_available, @@ -942,48 +941,6 @@ def validate_preprocess_arguments( raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.") -def validate_fast_preprocess_arguments( - do_rescale: Optional[bool] = None, - rescale_factor: Optional[float] = None, - do_normalize: Optional[bool] = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_pad: Optional[bool] = None, - size_divisibility: Optional[int] = None, - do_center_crop: Optional[bool] = None, - crop_size: Optional[Dict[str, int]] = None, - do_resize: Optional[bool] = None, - size: Optional[Dict[str, int]] = None, - resample: Optional["PILImageResampling"] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, -): - """ - Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method. - Raises `ValueError` if arguments incompatibility is caught. 
- """ - validate_preprocess_arguments( - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - do_pad=do_pad, - size_divisibility=size_divisibility, - do_center_crop=do_center_crop, - crop_size=crop_size, - do_resize=do_resize, - size=size, - resample=resample, - ) - # Extra checks for ImageProcessorFast - if return_tensors is not None and return_tensors != "pt": - raise ValueError("Only returning PyTorch tensors is currently supported.") - - if data_format != ChannelDimension.FIRST: - raise ValueError("Only channel first data format is currently supported.") - - # In the future we can add a TF implementation here when we have TF models. class ImageFeatureExtractionMixin: """ diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py index 850370e593..5d400f3d13 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py @@ -691,15 +691,8 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast): target_size=resized_image.size()[-2:], ) image = resized_image - - if do_rescale and do_normalize: - # fused rescale and normalize - image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std) - elif do_rescale: - image = image * rescale_factor - elif do_normalize: - image = F.normalize(image, image_mean, image_std) - + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) if do_convert_annotations and annotations is not None: annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py 
index 8d29a5796f..c6374db934 100644 --- a/src/transformers/models/detr/image_processing_detr_fast.py +++ b/src/transformers/models/detr/image_processing_detr_fast.py @@ -716,15 +716,8 @@ class DetrImageProcessorFast(BaseImageProcessorFast): target_size=resized_image.size()[-2:], ) image = resized_image - - if do_rescale and do_normalize: - # fused rescale and normalize - image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std) - elif do_rescale: - image = image * rescale_factor - elif do_normalize: - image = F.normalize(image, image_mean, image_std) - + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) if do_convert_annotations and annotations is not None: annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) diff --git a/src/transformers/models/gemma3/image_processing_gemma3_fast.py b/src/transformers/models/gemma3/image_processing_gemma3_fast.py index 50dfcb920f..6af4985dfd 100644 --- a/src/transformers/models/gemma3/image_processing_gemma3_fast.py +++ b/src/transformers/models/gemma3/image_processing_gemma3_fast.py @@ -25,7 +25,6 @@ from ...image_processing_utils_fast import ( BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs, - get_size_dict, group_images_by_shape, reorder_images, ) @@ -37,7 +36,6 @@ from ...image_utils import ( SizeDict, get_image_size, make_nested_list_of_images, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -255,61 +253,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast): images: ImageInput, **kwargs: Unpack[Gemma3FastImageProcessorKwargs], ) -> BatchFeature: - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. 
- for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Pop kwargs that need further processing or won't be used in _preprocess - default_to_square = kwargs.pop("default_to_square") - size = kwargs.pop("size") - crop_size = kwargs.pop("crop_size") - image_mean = kwargs.pop("image_mean") - image_std = kwargs.pop("image_std") - data_format = kwargs.pop("data_format") - resample = kwargs.pop("resample") - - # Make hashable for cache - size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None - crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None - image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean - image_std = tuple(image_std) if isinstance(image_std, list) else image_std - - image_mean, image_std, interpolation = self._prepare_process_arguments( - size=size, - crop_size=crop_size, - resample=resample, - image_mean=image_mean, - image_std=image_std, - data_format=data_format if data_format is not None else ChannelDimension.FIRST, - device=images[0][0].device, - do_resize=kwargs.get("do_resize"), - do_center_crop=kwargs.get("do_center_crop"), - do_rescale=kwargs.get("do_rescale"), - rescale_factor=kwargs.get("rescale_factor"), - do_normalize=kwargs.get("do_normalize"), - return_tensors=kwargs.get("return_tensors"), - ) - - return self._preprocess( - images=images, - size=size, - crop_size=crop_size, - interpolation=interpolation, - image_mean=image_mean, - image_std=image_std, - **kwargs, - ) + return 
super().preprocess(images, **kwargs) def _preprocess( self, diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py index b54f86bba6..20b7543671 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py @@ -49,7 +49,6 @@ from ...utils import ( is_torch_available, is_torchvision_available, is_torchvision_v2_available, - is_vision_available, logging, ) from .image_processing_qwen2_vl import smart_resize @@ -58,13 +57,14 @@ from .image_processing_qwen2_vl import smart_resize if is_torch_available(): import torch -if is_vision_available(): - pass -if is_torchvision_v2_available(): - from torchvision.transforms.v2 import functional as F -elif is_torchvision_available(): - from torchvision.transforms import functional as F +if is_torchvision_available(): + from ...image_utils import pil_torch_interpolation_mapping + + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F logger = logging.get_logger(__name__) @@ -311,19 +311,22 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast): image_mean = tuple(image_mean) if image_mean is not None else None image_std = tuple(image_std) if image_std is not None else None - image_mean, image_std, interpolation = self._prepare_process_arguments( - do_resize=do_resize, - size=size, - resample=resample, + self._validate_preprocess_kwargs( do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, return_tensors=return_tensors, data_format=data_format, - device=device, ) + interpolation = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + if images is not None: images 
= make_flat_list_of_images(images) if videos is not None: diff --git a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py index bd34843645..8e96a13b20 100644 --- a/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py +++ b/src/transformers/models/rt_detr/image_processing_rt_detr_fast.py @@ -486,15 +486,8 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast): target_size=resized_image.size()[-2:], ) image = resized_image - - if do_rescale and do_normalize: - # fused rescale and normalize - image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std) - elif do_rescale: - image = image * rescale_factor - elif do_normalize: - image = F.normalize(image, image_mean, image_std) - + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) if do_convert_annotations and annotations is not None: annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) diff --git a/src/transformers/models/rt_detr/modular_rt_detr.py b/src/transformers/models/rt_detr/modular_rt_detr.py index e1ee97b4da..74a849dd4a 100644 --- a/src/transformers/models/rt_detr/modular_rt_detr.py +++ b/src/transformers/models/rt_detr/modular_rt_detr.py @@ -266,15 +266,8 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast): target_size=resized_image.size()[-2:], ) image = resized_image - - if do_rescale and do_normalize: - # fused rescale and normalize - image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std) - elif do_rescale: - image = image * rescale_factor - elif do_normalize: - image = F.normalize(image, image_mean, image_std) - + # Fused rescale and normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) if do_convert_annotations and annotations is not None: annotation = 
self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) diff --git a/src/transformers/models/siglip2/image_processing_siglip2_fast.py b/src/transformers/models/siglip2/image_processing_siglip2_fast.py index 26948a507f..dcd7cef629 100644 --- a/src/transformers/models/siglip2/image_processing_siglip2_fast.py +++ b/src/transformers/models/siglip2/image_processing_siglip2_fast.py @@ -14,7 +14,6 @@ # limitations under the License. """Fast Image processor class for SigLIP2.""" -from functools import lru_cache from typing import List, Optional, Tuple, Union import torch @@ -117,11 +116,10 @@ class Siglip2ImageProcessorFast(BaseImageProcessorFast): def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]): super().__init__(**kwargs) - @lru_cache(maxsize=10) - def _prepare_process_arguments(self, **kwargs) -> tuple: + def _validate_preprocess_kwargs(self, **kwargs) -> tuple: # Remove do_resize from kwargs to not raise an error as size is None kwargs.pop("do_resize", None) - return super()._prepare_process_arguments(**kwargs) + return super()._validate_preprocess_kwargs(**kwargs) @add_start_docstrings( BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,