Fix rescale normalize inconsistencies in fast image processors (#36388)

* fix fused rescale normalize inconsistencies

* fix siglip2 fast image processor

* refactor kwargs validation and fused nirmalize rescale

* cleanup kwargs handling in preprocess

* update new procs after refactor
This commit is contained in:
Yoni Gozlan
2025-03-12 23:18:34 -04:00
committed by GitHub
parent 48292a9848
commit 79254c9b61
9 changed files with 164 additions and 226 deletions

View File

@@ -40,8 +40,8 @@ from .image_utils import (
get_image_type, get_image_type,
infer_channel_dimension_format, infer_channel_dimension_format,
make_flat_list_of_images, make_flat_list_of_images,
validate_fast_preprocess_arguments,
validate_kwargs, validate_kwargs,
validate_preprocess_arguments,
) )
from .processing_utils import Unpack from .processing_utils import Unpack
from .utils import ( from .utils import (
@@ -72,6 +72,49 @@ if is_torchvision_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
resample: Optional["PILImageResampling"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
"""
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
Raises `ValueError` if arguments incompatibility is caught.
"""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor": def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
""" """
Squeezes a tensor, but only if the axis specified has dim 1. Squeezes a tensor, but only if the axis specified has dim 1.
@@ -380,6 +423,23 @@ class BaseImageProcessorFast(BaseImageProcessor):
""" """
return F.normalize(image, mean, std) return F.normalize(image, mean, std)
@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
self,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
device: Optional["torch.device"] = None,
) -> tuple:
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
do_rescale = False
return image_mean, image_std, do_rescale
def rescale_and_normalize( def rescale_and_normalize(
self, self,
images: "torch.Tensor", images: "torch.Tensor",
@@ -392,12 +452,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
""" """
Rescale and normalize images. Rescale and normalize images.
""" """
if do_rescale and do_normalize: image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
device=images.device,
)
# if/elif as we use fused rescale and normalize if both are set to True
if do_normalize:
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std) images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale: elif do_rescale:
images = images * rescale_factor images = self.rescale(images, rescale_factor)
elif do_normalize:
images = self.normalize(images, image_mean, image_std)
return images return images
@@ -527,25 +594,60 @@ class BaseImageProcessorFast(BaseImageProcessor):
return processed_images return processed_images
@lru_cache(maxsize=10) def _further_process_kwargs(
def _prepare_process_arguments(
self, self,
do_resize: bool = None, size: Optional[SizeDict] = None,
size: Dict[str, int] = None, crop_size: Optional[SizeDict] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None, default_to_square: Optional[bool] = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None, image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, **kwargs,
device: Optional["torch.device"] = None, ) -> dict:
) -> tuple:
""" """
Prepare the arguments for the process method. Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if kwargs is None:
kwargs = {}
if size is not None:
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
if crop_size is not None:
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
if isinstance(image_mean, list):
image_mean = tuple(image_mean)
if isinstance(image_std, list):
image_std = tuple(image_std)
if data_format is None:
data_format = ChannelDimension.FIRST
kwargs["size"] = size
kwargs["crop_size"] = crop_size
kwargs["default_to_square"] = default_to_square
kwargs["image_mean"] = image_mean
kwargs["image_std"] = image_std
kwargs["data_format"] = data_format
return kwargs
def _validate_preprocess_kwargs(
self,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, tuple[float]]] = None,
image_std: Optional[Union[float, tuple[float]]] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
**kwargs,
):
"""
validate the kwargs for the preprocess method.
""" """
validate_fast_preprocess_arguments( validate_fast_preprocess_arguments(
do_rescale=do_rescale, do_rescale=do_rescale,
@@ -562,23 +664,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
data_format=data_format, data_format=data_format,
) )
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
return image_mean, image_std, interpolation
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS) @add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
def preprocess( def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
self,
images: ImageInput,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided # Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None. # by the user, it gets its default value from the instance, or is set to None.
@@ -589,51 +676,28 @@ class BaseImageProcessorFast(BaseImageProcessor):
do_convert_rgb = kwargs.pop("do_convert_rgb") do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format") input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device") device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images( images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
) )
# Pop kwargs that need further processing or won't be used in _preprocess # Update kwargs that need further processing before being validated
default_to_square = kwargs.pop("default_to_square") kwargs = self._further_process_kwargs(**kwargs)
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size") # Validate kwargs
image_mean = kwargs.pop("image_mean") self._validate_preprocess_kwargs(**kwargs)
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format") # torch resize uses interpolation instead of resample
resample = kwargs.pop("resample") resample = kwargs.pop("resample")
kwargs["interpolation"] = (
# Make hashable for cache pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
) )
return self._preprocess( # Pop kwargs that are not needed in _preprocess
images=images, kwargs.pop("default_to_square")
size=size, kwargs.pop("data_format")
crop_size=crop_size,
interpolation=interpolation, return self._preprocess(images=images, **kwargs)
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
def _preprocess( def _preprocess(
self, self,

View File

@@ -26,7 +26,6 @@ from packaging import version
from .utils import ( from .utils import (
ExplicitEnum, ExplicitEnum,
TensorType,
is_av_available, is_av_available,
is_cv2_available, is_cv2_available,
is_decord_available, is_decord_available,
@@ -942,48 +941,6 @@ def validate_preprocess_arguments(
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.") raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
def validate_fast_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
"""
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
Raises `ValueError` if arguments incompatibility is caught.
"""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
# In the future we can add a TF implementation here when we have TF models. # In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin: class ImageFeatureExtractionMixin:
""" """

View File

@@ -691,15 +691,8 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
target_size=resized_image.size()[-2:], target_size=resized_image.size()[-2:],
) )
image = resized_image image = resized_image
# Fused rescale and normalize
if do_rescale and do_normalize: image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None: if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -716,15 +716,8 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
target_size=resized_image.size()[-2:], target_size=resized_image.size()[-2:],
) )
image = resized_image image = resized_image
# Fused rescale and normalize
if do_rescale and do_normalize: image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None: if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -25,7 +25,6 @@ from ...image_processing_utils_fast import (
BaseImageProcessorFast, BaseImageProcessorFast,
BatchFeature, BatchFeature,
DefaultFastImageProcessorKwargs, DefaultFastImageProcessorKwargs,
get_size_dict,
group_images_by_shape, group_images_by_shape,
reorder_images, reorder_images,
) )
@@ -37,7 +36,6 @@ from ...image_utils import (
SizeDict, SizeDict,
get_image_size, get_image_size,
make_nested_list_of_images, make_nested_list_of_images,
validate_kwargs,
) )
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
@@ -255,61 +253,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
images: ImageInput, images: ImageInput,
**kwargs: Unpack[Gemma3FastImageProcessorKwargs], **kwargs: Unpack[Gemma3FastImageProcessorKwargs],
) -> BatchFeature: ) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) return super().preprocess(images, **kwargs)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Pop kwargs that need further processing or won't be used in _preprocess
default_to_square = kwargs.pop("default_to_square")
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size")
image_mean = kwargs.pop("image_mean")
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format")
resample = kwargs.pop("resample")
# Make hashable for cache
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0][0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
)
return self._preprocess(
images=images,
size=size,
crop_size=crop_size,
interpolation=interpolation,
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
def _preprocess( def _preprocess(
self, self,

View File

@@ -49,7 +49,6 @@ from ...utils import (
is_torch_available, is_torch_available,
is_torchvision_available, is_torchvision_available,
is_torchvision_v2_available, is_torchvision_v2_available,
is_vision_available,
logging, logging,
) )
from .image_processing_qwen2_vl import smart_resize from .image_processing_qwen2_vl import smart_resize
@@ -58,12 +57,13 @@ from .image_processing_qwen2_vl import smart_resize
if is_torch_available(): if is_torch_available():
import torch import torch
if is_vision_available():
pass if is_torchvision_available():
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available(): if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
elif is_torchvision_available(): else:
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -311,19 +311,22 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
image_mean = tuple(image_mean) if image_mean is not None else None image_mean = tuple(image_mean) if image_mean is not None else None
image_std = tuple(image_std) if image_std is not None else None image_std = tuple(image_std) if image_std is not None else None
image_mean, image_std, interpolation = self._prepare_process_arguments( self._validate_preprocess_kwargs(
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale, do_rescale=do_rescale,
rescale_factor=rescale_factor, rescale_factor=rescale_factor,
do_normalize=do_normalize, do_normalize=do_normalize,
image_mean=image_mean, image_mean=image_mean,
image_std=image_std, image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors, return_tensors=return_tensors,
data_format=data_format, data_format=data_format,
device=device,
) )
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
if images is not None: if images is not None:
images = make_flat_list_of_images(images) images = make_flat_list_of_images(images)
if videos is not None: if videos is not None:

View File

@@ -486,15 +486,8 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
target_size=resized_image.size()[-2:], target_size=resized_image.size()[-2:],
) )
image = resized_image image = resized_image
# Fused rescale and normalize
if do_rescale and do_normalize: image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None: if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -266,15 +266,8 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
target_size=resized_image.size()[-2:], target_size=resized_image.size()[-2:],
) )
image = resized_image image = resized_image
# Fused rescale and normalize
if do_rescale and do_normalize: image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None: if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST)) annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""Fast Image processor class for SigLIP2.""" """Fast Image processor class for SigLIP2."""
from functools import lru_cache
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -117,11 +116,10 @@ class Siglip2ImageProcessorFast(BaseImageProcessorFast):
def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]): def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
super().__init__(**kwargs) super().__init__(**kwargs)
@lru_cache(maxsize=10) def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
def _prepare_process_arguments(self, **kwargs) -> tuple:
# Remove do_resize from kwargs to not raise an error as size is None # Remove do_resize from kwargs to not raise an error as size is None
kwargs.pop("do_resize", None) kwargs.pop("do_resize", None)
return super()._prepare_process_arguments(**kwargs) return super()._validate_preprocess_kwargs(**kwargs)
@add_start_docstrings( @add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,