Fix rescale normalize inconsistencies in fast image processors (#36388)

* fix fused rescale normalize inconsistencies

* fix siglip2 fast image processor

* refactor kwargs validation and fused nirmalize rescale

* cleanup kwargs handling in preprocess

* update new procs after refactor
This commit is contained in:
Yoni Gozlan
2025-03-12 23:18:34 -04:00
committed by GitHub
parent 48292a9848
commit 79254c9b61
9 changed files with 164 additions and 226 deletions

View File

@@ -40,8 +40,8 @@ from .image_utils import (
get_image_type,
infer_channel_dimension_format,
make_flat_list_of_images,
validate_fast_preprocess_arguments,
validate_kwargs,
validate_preprocess_arguments,
)
from .processing_utils import Unpack
from .utils import (
@@ -72,6 +72,49 @@ if is_torchvision_available():
logger = logging.get_logger(__name__)
@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
resample: Optional["PILImageResampling"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
"""
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
Raises `ValueError` if arguments incompatibility is caught.
"""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
"""
Squeezes a tensor, but only if the axis specified has dim 1.
@@ -380,6 +423,23 @@ class BaseImageProcessorFast(BaseImageProcessor):
"""
return F.normalize(image, mean, std)
@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
self,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
device: Optional["torch.device"] = None,
) -> tuple:
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
do_rescale = False
return image_mean, image_std, do_rescale
def rescale_and_normalize(
self,
images: "torch.Tensor",
@@ -392,12 +452,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
"""
Rescale and normalize images.
"""
if do_rescale and do_normalize:
image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
device=images.device,
)
# if/elif as we use fused rescale and normalize if both are set to True
if do_normalize:
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
images = images * rescale_factor
elif do_normalize:
images = self.normalize(images, image_mean, image_std)
images = self.rescale(images, rescale_factor)
return images
@@ -527,25 +594,60 @@ class BaseImageProcessorFast(BaseImageProcessor):
return processed_images
@lru_cache(maxsize=10)
def _prepare_process_arguments(
def _further_process_kwargs(
self,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
size: Optional[SizeDict] = None,
crop_size: Optional[SizeDict] = None,
default_to_square: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
device: Optional["torch.device"] = None,
) -> tuple:
data_format: Optional[ChannelDimension] = None,
**kwargs,
) -> dict:
"""
Prepare the arguments for the process method.
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if kwargs is None:
kwargs = {}
if size is not None:
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
if crop_size is not None:
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
if isinstance(image_mean, list):
image_mean = tuple(image_mean)
if isinstance(image_std, list):
image_std = tuple(image_std)
if data_format is None:
data_format = ChannelDimension.FIRST
kwargs["size"] = size
kwargs["crop_size"] = crop_size
kwargs["default_to_square"] = default_to_square
kwargs["image_mean"] = image_mean
kwargs["image_std"] = image_std
kwargs["data_format"] = data_format
return kwargs
def _validate_preprocess_kwargs(
self,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, tuple[float]]] = None,
image_std: Optional[Union[float, tuple[float]]] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
**kwargs,
):
"""
validate the kwargs for the preprocess method.
"""
validate_fast_preprocess_arguments(
do_rescale=do_rescale,
@@ -562,23 +664,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
data_format=data_format,
)
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
return image_mean, image_std, interpolation
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
@@ -589,51 +676,28 @@ class BaseImageProcessorFast(BaseImageProcessor):
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Pop kwargs that need further processing or won't be used in _preprocess
default_to_square = kwargs.pop("default_to_square")
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size")
image_mean = kwargs.pop("image_mean")
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format")
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
# Make hashable for cache
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
return self._preprocess(
images=images,
size=size,
crop_size=crop_size,
interpolation=interpolation,
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")
return self._preprocess(images=images, **kwargs)
def _preprocess(
self,

View File

@@ -26,7 +26,6 @@ from packaging import version
from .utils import (
ExplicitEnum,
TensorType,
is_av_available,
is_cv2_available,
is_decord_available,
@@ -942,48 +941,6 @@ def validate_preprocess_arguments(
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
def validate_fast_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
"""
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
Raises `ValueError` if arguments incompatibility is caught.
"""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""

View File

@@ -691,15 +691,8 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
target_size=resized_image.size()[-2:],
)
image = resized_image
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
# Fused rescale and normalize
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -716,15 +716,8 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
target_size=resized_image.size()[-2:],
)
image = resized_image
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
# Fused rescale and normalize
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -25,7 +25,6 @@ from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
get_size_dict,
group_images_by_shape,
reorder_images,
)
@@ -37,7 +36,6 @@ from ...image_utils import (
SizeDict,
get_image_size,
make_nested_list_of_images,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
@@ -255,61 +253,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
images: ImageInput,
**kwargs: Unpack[Gemma3FastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Pop kwargs that need further processing or won't be used in _preprocess
default_to_square = kwargs.pop("default_to_square")
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size")
image_mean = kwargs.pop("image_mean")
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format")
resample = kwargs.pop("resample")
# Make hashable for cache
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0][0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
)
return self._preprocess(
images=images,
size=size,
crop_size=crop_size,
interpolation=interpolation,
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
return super().preprocess(images, **kwargs)
def _preprocess(
self,

View File

@@ -49,7 +49,6 @@ from ...utils import (
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
)
from .image_processing_qwen2_vl import smart_resize
@@ -58,12 +57,13 @@ from .image_processing_qwen2_vl import smart_resize
if is_torch_available():
import torch
if is_vision_available():
pass
if is_torchvision_available():
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
else:
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
@@ -311,19 +311,22 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
image_mean = tuple(image_mean) if image_mean is not None else None
image_std = tuple(image_std) if image_std is not None else None
image_mean, image_std, interpolation = self._prepare_process_arguments(
do_resize=do_resize,
size=size,
resample=resample,
self._validate_preprocess_kwargs(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
device=device,
)
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
if images is not None:
images = make_flat_list_of_images(images)
if videos is not None:

View File

@@ -486,15 +486,8 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
target_size=resized_image.size()[-2:],
)
image = resized_image
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
# Fused rescale and normalize
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -266,15 +266,8 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
target_size=resized_image.size()[-2:],
)
image = resized_image
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
# Fused rescale and normalize
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))

View File

@@ -14,7 +14,6 @@
# limitations under the License.
"""Fast Image processor class for SigLIP2."""
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import torch
@@ -117,11 +116,10 @@ class Siglip2ImageProcessorFast(BaseImageProcessorFast):
def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
super().__init__(**kwargs)
@lru_cache(maxsize=10)
def _prepare_process_arguments(self, **kwargs) -> tuple:
def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
# Remove do_resize from kwargs to not raise an error as size is None
kwargs.pop("do_resize", None)
return super()._prepare_process_arguments(**kwargs)
return super()._validate_preprocess_kwargs(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,