Fix rescale normalize inconsistencies in fast image processors (#36388)

* fix fused rescale normalize inconsistencies

* fix siglip2 fast image processor

* refactor kwargs validation and fused nirmalize rescale

* cleanup kwargs handling in preprocess

* update new procs after refactor
This commit is contained in:
Yoni Gozlan
2025-03-12 23:18:34 -04:00
committed by GitHub
parent 48292a9848
commit 79254c9b61
9 changed files with 164 additions and 226 deletions

View File

@@ -40,8 +40,8 @@ from .image_utils import (
get_image_type,
infer_channel_dimension_format,
make_flat_list_of_images,
validate_fast_preprocess_arguments,
validate_kwargs,
validate_preprocess_arguments,
)
from .processing_utils import Unpack
from .utils import (
@@ -72,6 +72,49 @@ if is_torchvision_available():
logger = logging.get_logger(__name__)
@lru_cache(maxsize=10)
def validate_fast_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
resample: Optional["PILImageResampling"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
"""
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
Raises `ValueError` if arguments incompatibility is caught.
"""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
"""
Squeezes a tensor, but only if the axis specified has dim 1.
@@ -380,6 +423,23 @@ class BaseImageProcessorFast(BaseImageProcessor):
"""
return F.normalize(image, mean, std)
@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
self,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
device: Optional["torch.device"] = None,
) -> tuple:
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
do_rescale = False
return image_mean, image_std, do_rescale
def rescale_and_normalize(
self,
images: "torch.Tensor",
@@ -392,12 +452,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
"""
Rescale and normalize images.
"""
if do_rescale and do_normalize:
image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
device=images.device,
)
# if/elif as we use fused rescale and normalize if both are set to True
if do_normalize:
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
images = images * rescale_factor
elif do_normalize:
images = self.normalize(images, image_mean, image_std)
images = self.rescale(images, rescale_factor)
return images
@@ -527,25 +594,60 @@ class BaseImageProcessorFast(BaseImageProcessor):
return processed_images
@lru_cache(maxsize=10)
def _prepare_process_arguments(
def _further_process_kwargs(
self,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
size: Optional[SizeDict] = None,
crop_size: Optional[SizeDict] = None,
default_to_square: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
device: Optional["torch.device"] = None,
) -> tuple:
data_format: Optional[ChannelDimension] = None,
**kwargs,
) -> dict:
"""
Prepare the arguments for the process method.
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if kwargs is None:
kwargs = {}
if size is not None:
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
if crop_size is not None:
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
if isinstance(image_mean, list):
image_mean = tuple(image_mean)
if isinstance(image_std, list):
image_std = tuple(image_std)
if data_format is None:
data_format = ChannelDimension.FIRST
kwargs["size"] = size
kwargs["crop_size"] = crop_size
kwargs["default_to_square"] = default_to_square
kwargs["image_mean"] = image_mean
kwargs["image_std"] = image_std
kwargs["data_format"] = data_format
return kwargs
def _validate_preprocess_kwargs(
self,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, tuple[float]]] = None,
image_std: Optional[Union[float, tuple[float]]] = None,
do_resize: Optional[bool] = None,
size: Optional[SizeDict] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
**kwargs,
):
"""
validate the kwargs for the preprocess method.
"""
validate_fast_preprocess_arguments(
do_rescale=do_rescale,
@@ -562,23 +664,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
data_format=data_format,
)
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
return image_mean, image_std, interpolation
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
@@ -589,51 +676,28 @@ class BaseImageProcessorFast(BaseImageProcessor):
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Pop kwargs that need further processing or won't be used in _preprocess
default_to_square = kwargs.pop("default_to_square")
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size")
image_mean = kwargs.pop("image_mean")
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format")
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
# Make hashable for cache
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
return self._preprocess(
images=images,
size=size,
crop_size=crop_size,
interpolation=interpolation,
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")
return self._preprocess(images=images, **kwargs)
def _preprocess(
self,