Fix rescale normalize inconsistencies in fast image processors (#36388)
* fix fused rescale normalize inconsistencies * fix siglip2 fast image processor * refactor kwargs validation and fused nirmalize rescale * cleanup kwargs handling in preprocess * update new procs after refactor
This commit is contained in:
@@ -40,8 +40,8 @@ from .image_utils import (
|
||||
get_image_type,
|
||||
infer_channel_dimension_format,
|
||||
make_flat_list_of_images,
|
||||
validate_fast_preprocess_arguments,
|
||||
validate_kwargs,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from .processing_utils import Unpack
|
||||
from .utils import (
|
||||
@@ -72,6 +72,49 @@ if is_torchvision_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def validate_fast_preprocess_arguments(
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
size_divisibility: Optional[int] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
crop_size: Optional[SizeDict] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[SizeDict] = None,
|
||||
resample: Optional["PILImageResampling"] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
):
|
||||
"""
|
||||
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
|
||||
Raises `ValueError` if arguments incompatibility is caught.
|
||||
"""
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_pad=do_pad,
|
||||
size_divisibility=size_divisibility,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
# Extra checks for ImageProcessorFast
|
||||
if return_tensors is not None and return_tensors != "pt":
|
||||
raise ValueError("Only returning PyTorch tensors is currently supported.")
|
||||
|
||||
if data_format != ChannelDimension.FIRST:
|
||||
raise ValueError("Only channel first data format is currently supported.")
|
||||
|
||||
|
||||
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
|
||||
"""
|
||||
Squeezes a tensor, but only if the axis specified has dim 1.
|
||||
@@ -380,6 +423,23 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
"""
|
||||
return F.normalize(image, mean, std)
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def _fuse_mean_std_and_rescale_factor(
|
||||
self,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> tuple:
|
||||
if do_rescale and do_normalize:
|
||||
# Fused rescale and normalize
|
||||
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
|
||||
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
|
||||
do_rescale = False
|
||||
return image_mean, image_std, do_rescale
|
||||
|
||||
def rescale_and_normalize(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
@@ -392,12 +452,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
"""
|
||||
Rescale and normalize images.
|
||||
"""
|
||||
if do_rescale and do_normalize:
|
||||
image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
device=images.device,
|
||||
)
|
||||
# if/elif as we use fused rescale and normalize if both are set to True
|
||||
if do_normalize:
|
||||
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
|
||||
elif do_rescale:
|
||||
images = images * rescale_factor
|
||||
elif do_normalize:
|
||||
images = self.normalize(images, image_mean, image_std)
|
||||
images = self.rescale(images, rescale_factor)
|
||||
|
||||
return images
|
||||
|
||||
@@ -527,25 +594,60 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
|
||||
return processed_images
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def _prepare_process_arguments(
|
||||
def _further_process_kwargs(
|
||||
self,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
|
||||
do_center_crop: bool = None,
|
||||
crop_size: int = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
size: Optional[SizeDict] = None,
|
||||
crop_size: Optional[SizeDict] = None,
|
||||
default_to_square: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> tuple:
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Prepare the arguments for the process method.
|
||||
Update kwargs that need further processing before being validated
|
||||
Can be overridden by subclasses to customize the processing of kwargs.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if size is not None:
|
||||
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
|
||||
if crop_size is not None:
|
||||
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
|
||||
if isinstance(image_mean, list):
|
||||
image_mean = tuple(image_mean)
|
||||
if isinstance(image_std, list):
|
||||
image_std = tuple(image_std)
|
||||
if data_format is None:
|
||||
data_format = ChannelDimension.FIRST
|
||||
|
||||
kwargs["size"] = size
|
||||
kwargs["crop_size"] = crop_size
|
||||
kwargs["default_to_square"] = default_to_square
|
||||
kwargs["image_mean"] = image_mean
|
||||
kwargs["image_std"] = image_std
|
||||
kwargs["data_format"] = data_format
|
||||
|
||||
return kwargs
|
||||
|
||||
def _validate_preprocess_kwargs(
|
||||
self,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, tuple[float]]] = None,
|
||||
image_std: Optional[Union[float, tuple[float]]] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[SizeDict] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
crop_size: Optional[SizeDict] = None,
|
||||
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
validate the kwargs for the preprocess method.
|
||||
"""
|
||||
validate_fast_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
@@ -562,23 +664,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
if do_rescale and do_normalize:
|
||||
# Fused rescale and normalize
|
||||
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
|
||||
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
|
||||
|
||||
interpolation = (
|
||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||
)
|
||||
|
||||
return image_mean, image_std, interpolation
|
||||
|
||||
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
@@ -589,51 +676,28 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
|
||||
# Prepare input images
|
||||
images = self._prepare_input_images(
|
||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||
)
|
||||
|
||||
# Pop kwargs that need further processing or won't be used in _preprocess
|
||||
default_to_square = kwargs.pop("default_to_square")
|
||||
size = kwargs.pop("size")
|
||||
crop_size = kwargs.pop("crop_size")
|
||||
image_mean = kwargs.pop("image_mean")
|
||||
image_std = kwargs.pop("image_std")
|
||||
data_format = kwargs.pop("data_format")
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
resample = kwargs.pop("resample")
|
||||
|
||||
# Make hashable for cache
|
||||
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
|
||||
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
|
||||
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
|
||||
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
|
||||
|
||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
||||
size=size,
|
||||
crop_size=crop_size,
|
||||
resample=resample,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
|
||||
device=images[0].device,
|
||||
do_resize=kwargs.get("do_resize"),
|
||||
do_center_crop=kwargs.get("do_center_crop"),
|
||||
do_rescale=kwargs.get("do_rescale"),
|
||||
rescale_factor=kwargs.get("rescale_factor"),
|
||||
do_normalize=kwargs.get("do_normalize"),
|
||||
return_tensors=kwargs.get("return_tensors"),
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||
)
|
||||
|
||||
return self._preprocess(
|
||||
images=images,
|
||||
size=size,
|
||||
crop_size=crop_size,
|
||||
interpolation=interpolation,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
**kwargs,
|
||||
)
|
||||
# Pop kwargs that are not needed in _preprocess
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
return self._preprocess(images=images, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user