Fix rescale normalize inconsistencies in fast image processors (#36388)
* fix fused rescale normalize inconsistencies * fix siglip2 fast image processor * refactor kwargs validation and fused nirmalize rescale * cleanup kwargs handling in preprocess * update new procs after refactor
This commit is contained in:
@@ -40,8 +40,8 @@ from .image_utils import (
|
|||||||
get_image_type,
|
get_image_type,
|
||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
make_flat_list_of_images,
|
make_flat_list_of_images,
|
||||||
validate_fast_preprocess_arguments,
|
|
||||||
validate_kwargs,
|
validate_kwargs,
|
||||||
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from .processing_utils import Unpack
|
from .processing_utils import Unpack
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -72,6 +72,49 @@ if is_torchvision_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
|
||||||
|
def validate_fast_preprocess_arguments(
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_pad: Optional[bool] = None,
|
||||||
|
size_divisibility: Optional[int] = None,
|
||||||
|
do_center_crop: Optional[bool] = None,
|
||||||
|
crop_size: Optional[SizeDict] = None,
|
||||||
|
do_resize: Optional[bool] = None,
|
||||||
|
size: Optional[SizeDict] = None,
|
||||||
|
resample: Optional["PILImageResampling"] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
|
||||||
|
Raises `ValueError` if arguments incompatibility is caught.
|
||||||
|
"""
|
||||||
|
validate_preprocess_arguments(
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_pad=do_pad,
|
||||||
|
size_divisibility=size_divisibility,
|
||||||
|
do_center_crop=do_center_crop,
|
||||||
|
crop_size=crop_size,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
)
|
||||||
|
# Extra checks for ImageProcessorFast
|
||||||
|
if return_tensors is not None and return_tensors != "pt":
|
||||||
|
raise ValueError("Only returning PyTorch tensors is currently supported.")
|
||||||
|
|
||||||
|
if data_format != ChannelDimension.FIRST:
|
||||||
|
raise ValueError("Only channel first data format is currently supported.")
|
||||||
|
|
||||||
|
|
||||||
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
|
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
|
||||||
"""
|
"""
|
||||||
Squeezes a tensor, but only if the axis specified has dim 1.
|
Squeezes a tensor, but only if the axis specified has dim 1.
|
||||||
@@ -380,6 +423,23 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
|||||||
"""
|
"""
|
||||||
return F.normalize(image, mean, std)
|
return F.normalize(image, mean, std)
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
|
||||||
|
def _fuse_mean_std_and_rescale_factor(
|
||||||
|
self,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
device: Optional["torch.device"] = None,
|
||||||
|
) -> tuple:
|
||||||
|
if do_rescale and do_normalize:
|
||||||
|
# Fused rescale and normalize
|
||||||
|
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
|
||||||
|
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
|
||||||
|
do_rescale = False
|
||||||
|
return image_mean, image_std, do_rescale
|
||||||
|
|
||||||
def rescale_and_normalize(
|
def rescale_and_normalize(
|
||||||
self,
|
self,
|
||||||
images: "torch.Tensor",
|
images: "torch.Tensor",
|
||||||
@@ -392,12 +452,19 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
|||||||
"""
|
"""
|
||||||
Rescale and normalize images.
|
Rescale and normalize images.
|
||||||
"""
|
"""
|
||||||
if do_rescale and do_normalize:
|
image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
device=images.device,
|
||||||
|
)
|
||||||
|
# if/elif as we use fused rescale and normalize if both are set to True
|
||||||
|
if do_normalize:
|
||||||
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
|
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
|
||||||
elif do_rescale:
|
elif do_rescale:
|
||||||
images = images * rescale_factor
|
images = self.rescale(images, rescale_factor)
|
||||||
elif do_normalize:
|
|
||||||
images = self.normalize(images, image_mean, image_std)
|
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
@@ -527,25 +594,60 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
|||||||
|
|
||||||
return processed_images
|
return processed_images
|
||||||
|
|
||||||
@lru_cache(maxsize=10)
|
def _further_process_kwargs(
|
||||||
def _prepare_process_arguments(
|
|
||||||
self,
|
self,
|
||||||
do_resize: bool = None,
|
size: Optional[SizeDict] = None,
|
||||||
size: Dict[str, int] = None,
|
crop_size: Optional[SizeDict] = None,
|
||||||
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
|
default_to_square: Optional[bool] = None,
|
||||||
do_center_crop: bool = None,
|
|
||||||
crop_size: int = None,
|
|
||||||
do_rescale: bool = None,
|
|
||||||
rescale_factor: float = None,
|
|
||||||
do_normalize: bool = None,
|
|
||||||
image_mean: Optional[Union[float, List[float]]] = None,
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
image_std: Optional[Union[float, List[float]]] = None,
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
data_format: Optional[ChannelDimension] = None,
|
||||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
**kwargs,
|
||||||
device: Optional["torch.device"] = None,
|
) -> dict:
|
||||||
) -> tuple:
|
|
||||||
"""
|
"""
|
||||||
Prepare the arguments for the process method.
|
Update kwargs that need further processing before being validated
|
||||||
|
Can be overridden by subclasses to customize the processing of kwargs.
|
||||||
|
"""
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
if size is not None:
|
||||||
|
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
|
||||||
|
if crop_size is not None:
|
||||||
|
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
|
||||||
|
if isinstance(image_mean, list):
|
||||||
|
image_mean = tuple(image_mean)
|
||||||
|
if isinstance(image_std, list):
|
||||||
|
image_std = tuple(image_std)
|
||||||
|
if data_format is None:
|
||||||
|
data_format = ChannelDimension.FIRST
|
||||||
|
|
||||||
|
kwargs["size"] = size
|
||||||
|
kwargs["crop_size"] = crop_size
|
||||||
|
kwargs["default_to_square"] = default_to_square
|
||||||
|
kwargs["image_mean"] = image_mean
|
||||||
|
kwargs["image_std"] = image_std
|
||||||
|
kwargs["data_format"] = data_format
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _validate_preprocess_kwargs(
|
||||||
|
self,
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
|
image_mean: Optional[Union[float, tuple[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, tuple[float]]] = None,
|
||||||
|
do_resize: Optional[bool] = None,
|
||||||
|
size: Optional[SizeDict] = None,
|
||||||
|
do_center_crop: Optional[bool] = None,
|
||||||
|
crop_size: Optional[SizeDict] = None,
|
||||||
|
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
validate the kwargs for the preprocess method.
|
||||||
"""
|
"""
|
||||||
validate_fast_preprocess_arguments(
|
validate_fast_preprocess_arguments(
|
||||||
do_rescale=do_rescale,
|
do_rescale=do_rescale,
|
||||||
@@ -562,23 +664,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
|||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
if do_rescale and do_normalize:
|
|
||||||
# Fused rescale and normalize
|
|
||||||
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
|
|
||||||
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
|
|
||||||
|
|
||||||
interpolation = (
|
|
||||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_mean, image_std, interpolation
|
|
||||||
|
|
||||||
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
|
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
|
||||||
def preprocess(
|
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||||
self,
|
|
||||||
images: ImageInput,
|
|
||||||
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
|
|
||||||
) -> BatchFeature:
|
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||||
# by the user, it gets its default value from the instance, or is set to None.
|
# by the user, it gets its default value from the instance, or is set to None.
|
||||||
@@ -589,51 +676,28 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
|||||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||||
input_data_format = kwargs.pop("input_data_format")
|
input_data_format = kwargs.pop("input_data_format")
|
||||||
device = kwargs.pop("device")
|
device = kwargs.pop("device")
|
||||||
|
# Prepare input images
|
||||||
images = self._prepare_input_images(
|
images = self._prepare_input_images(
|
||||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pop kwargs that need further processing or won't be used in _preprocess
|
# Update kwargs that need further processing before being validated
|
||||||
default_to_square = kwargs.pop("default_to_square")
|
kwargs = self._further_process_kwargs(**kwargs)
|
||||||
size = kwargs.pop("size")
|
|
||||||
crop_size = kwargs.pop("crop_size")
|
# Validate kwargs
|
||||||
image_mean = kwargs.pop("image_mean")
|
self._validate_preprocess_kwargs(**kwargs)
|
||||||
image_std = kwargs.pop("image_std")
|
|
||||||
data_format = kwargs.pop("data_format")
|
# torch resize uses interpolation instead of resample
|
||||||
resample = kwargs.pop("resample")
|
resample = kwargs.pop("resample")
|
||||||
|
kwargs["interpolation"] = (
|
||||||
# Make hashable for cache
|
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||||
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
|
|
||||||
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
|
|
||||||
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
|
|
||||||
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
|
|
||||||
|
|
||||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
|
||||||
size=size,
|
|
||||||
crop_size=crop_size,
|
|
||||||
resample=resample,
|
|
||||||
image_mean=image_mean,
|
|
||||||
image_std=image_std,
|
|
||||||
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
|
|
||||||
device=images[0].device,
|
|
||||||
do_resize=kwargs.get("do_resize"),
|
|
||||||
do_center_crop=kwargs.get("do_center_crop"),
|
|
||||||
do_rescale=kwargs.get("do_rescale"),
|
|
||||||
rescale_factor=kwargs.get("rescale_factor"),
|
|
||||||
do_normalize=kwargs.get("do_normalize"),
|
|
||||||
return_tensors=kwargs.get("return_tensors"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._preprocess(
|
# Pop kwargs that are not needed in _preprocess
|
||||||
images=images,
|
kwargs.pop("default_to_square")
|
||||||
size=size,
|
kwargs.pop("data_format")
|
||||||
crop_size=crop_size,
|
|
||||||
interpolation=interpolation,
|
return self._preprocess(images=images, **kwargs)
|
||||||
image_mean=image_mean,
|
|
||||||
image_std=image_std,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _preprocess(
|
def _preprocess(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from packaging import version
|
|||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
TensorType,
|
|
||||||
is_av_available,
|
is_av_available,
|
||||||
is_cv2_available,
|
is_cv2_available,
|
||||||
is_decord_available,
|
is_decord_available,
|
||||||
@@ -942,48 +941,6 @@ def validate_preprocess_arguments(
|
|||||||
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
||||||
|
|
||||||
|
|
||||||
def validate_fast_preprocess_arguments(
|
|
||||||
do_rescale: Optional[bool] = None,
|
|
||||||
rescale_factor: Optional[float] = None,
|
|
||||||
do_normalize: Optional[bool] = None,
|
|
||||||
image_mean: Optional[Union[float, List[float]]] = None,
|
|
||||||
image_std: Optional[Union[float, List[float]]] = None,
|
|
||||||
do_pad: Optional[bool] = None,
|
|
||||||
size_divisibility: Optional[int] = None,
|
|
||||||
do_center_crop: Optional[bool] = None,
|
|
||||||
crop_size: Optional[Dict[str, int]] = None,
|
|
||||||
do_resize: Optional[bool] = None,
|
|
||||||
size: Optional[Dict[str, int]] = None,
|
|
||||||
resample: Optional["PILImageResampling"] = None,
|
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
|
||||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
|
|
||||||
Raises `ValueError` if arguments incompatibility is caught.
|
|
||||||
"""
|
|
||||||
validate_preprocess_arguments(
|
|
||||||
do_rescale=do_rescale,
|
|
||||||
rescale_factor=rescale_factor,
|
|
||||||
do_normalize=do_normalize,
|
|
||||||
image_mean=image_mean,
|
|
||||||
image_std=image_std,
|
|
||||||
do_pad=do_pad,
|
|
||||||
size_divisibility=size_divisibility,
|
|
||||||
do_center_crop=do_center_crop,
|
|
||||||
crop_size=crop_size,
|
|
||||||
do_resize=do_resize,
|
|
||||||
size=size,
|
|
||||||
resample=resample,
|
|
||||||
)
|
|
||||||
# Extra checks for ImageProcessorFast
|
|
||||||
if return_tensors is not None and return_tensors != "pt":
|
|
||||||
raise ValueError("Only returning PyTorch tensors is currently supported.")
|
|
||||||
|
|
||||||
if data_format != ChannelDimension.FIRST:
|
|
||||||
raise ValueError("Only channel first data format is currently supported.")
|
|
||||||
|
|
||||||
|
|
||||||
# In the future we can add a TF implementation here when we have TF models.
|
# In the future we can add a TF implementation here when we have TF models.
|
||||||
class ImageFeatureExtractionMixin:
|
class ImageFeatureExtractionMixin:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -691,15 +691,8 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
|
|||||||
target_size=resized_image.size()[-2:],
|
target_size=resized_image.size()[-2:],
|
||||||
)
|
)
|
||||||
image = resized_image
|
image = resized_image
|
||||||
|
# Fused rescale and normalize
|
||||||
if do_rescale and do_normalize:
|
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||||
# fused rescale and normalize
|
|
||||||
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
|
|
||||||
elif do_rescale:
|
|
||||||
image = image * rescale_factor
|
|
||||||
elif do_normalize:
|
|
||||||
image = F.normalize(image, image_mean, image_std)
|
|
||||||
|
|
||||||
if do_convert_annotations and annotations is not None:
|
if do_convert_annotations and annotations is not None:
|
||||||
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
||||||
|
|
||||||
|
|||||||
@@ -716,15 +716,8 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
|
|||||||
target_size=resized_image.size()[-2:],
|
target_size=resized_image.size()[-2:],
|
||||||
)
|
)
|
||||||
image = resized_image
|
image = resized_image
|
||||||
|
# Fused rescale and normalize
|
||||||
if do_rescale and do_normalize:
|
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||||
# fused rescale and normalize
|
|
||||||
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
|
|
||||||
elif do_rescale:
|
|
||||||
image = image * rescale_factor
|
|
||||||
elif do_normalize:
|
|
||||||
image = F.normalize(image, image_mean, image_std)
|
|
||||||
|
|
||||||
if do_convert_annotations and annotations is not None:
|
if do_convert_annotations and annotations is not None:
|
||||||
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from ...image_processing_utils_fast import (
|
|||||||
BaseImageProcessorFast,
|
BaseImageProcessorFast,
|
||||||
BatchFeature,
|
BatchFeature,
|
||||||
DefaultFastImageProcessorKwargs,
|
DefaultFastImageProcessorKwargs,
|
||||||
get_size_dict,
|
|
||||||
group_images_by_shape,
|
group_images_by_shape,
|
||||||
reorder_images,
|
reorder_images,
|
||||||
)
|
)
|
||||||
@@ -37,7 +36,6 @@ from ...image_utils import (
|
|||||||
SizeDict,
|
SizeDict,
|
||||||
get_image_size,
|
get_image_size,
|
||||||
make_nested_list_of_images,
|
make_nested_list_of_images,
|
||||||
validate_kwargs,
|
|
||||||
)
|
)
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -255,61 +253,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
**kwargs: Unpack[Gemma3FastImageProcessorKwargs],
|
**kwargs: Unpack[Gemma3FastImageProcessorKwargs],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
return super().preprocess(images, **kwargs)
|
||||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
|
||||||
# by the user, it gets its default value from the instance, or is set to None.
|
|
||||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
|
||||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
|
||||||
|
|
||||||
# Extract parameters that are only used for preparing the input images
|
|
||||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
|
||||||
input_data_format = kwargs.pop("input_data_format")
|
|
||||||
device = kwargs.pop("device")
|
|
||||||
|
|
||||||
images = self._prepare_input_images(
|
|
||||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pop kwargs that need further processing or won't be used in _preprocess
|
|
||||||
default_to_square = kwargs.pop("default_to_square")
|
|
||||||
size = kwargs.pop("size")
|
|
||||||
crop_size = kwargs.pop("crop_size")
|
|
||||||
image_mean = kwargs.pop("image_mean")
|
|
||||||
image_std = kwargs.pop("image_std")
|
|
||||||
data_format = kwargs.pop("data_format")
|
|
||||||
resample = kwargs.pop("resample")
|
|
||||||
|
|
||||||
# Make hashable for cache
|
|
||||||
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
|
|
||||||
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
|
|
||||||
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
|
|
||||||
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
|
|
||||||
|
|
||||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
|
||||||
size=size,
|
|
||||||
crop_size=crop_size,
|
|
||||||
resample=resample,
|
|
||||||
image_mean=image_mean,
|
|
||||||
image_std=image_std,
|
|
||||||
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
|
|
||||||
device=images[0][0].device,
|
|
||||||
do_resize=kwargs.get("do_resize"),
|
|
||||||
do_center_crop=kwargs.get("do_center_crop"),
|
|
||||||
do_rescale=kwargs.get("do_rescale"),
|
|
||||||
rescale_factor=kwargs.get("rescale_factor"),
|
|
||||||
do_normalize=kwargs.get("do_normalize"),
|
|
||||||
return_tensors=kwargs.get("return_tensors"),
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._preprocess(
|
|
||||||
images=images,
|
|
||||||
size=size,
|
|
||||||
crop_size=crop_size,
|
|
||||||
interpolation=interpolation,
|
|
||||||
image_mean=image_mean,
|
|
||||||
image_std=image_std,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _preprocess(
|
def _preprocess(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ from ...utils import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
is_torchvision_v2_available,
|
is_torchvision_v2_available,
|
||||||
is_vision_available,
|
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from .image_processing_qwen2_vl import smart_resize
|
from .image_processing_qwen2_vl import smart_resize
|
||||||
@@ -58,12 +57,13 @@ from .image_processing_qwen2_vl import smart_resize
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if is_vision_available():
|
|
||||||
pass
|
|
||||||
|
|
||||||
if is_torchvision_v2_available():
|
if is_torchvision_available():
|
||||||
|
from ...image_utils import pil_torch_interpolation_mapping
|
||||||
|
|
||||||
|
if is_torchvision_v2_available():
|
||||||
from torchvision.transforms.v2 import functional as F
|
from torchvision.transforms.v2 import functional as F
|
||||||
elif is_torchvision_available():
|
else:
|
||||||
from torchvision.transforms import functional as F
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -311,19 +311,22 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
image_mean = tuple(image_mean) if image_mean is not None else None
|
image_mean = tuple(image_mean) if image_mean is not None else None
|
||||||
image_std = tuple(image_std) if image_std is not None else None
|
image_std = tuple(image_std) if image_std is not None else None
|
||||||
|
|
||||||
image_mean, image_std, interpolation = self._prepare_process_arguments(
|
self._validate_preprocess_kwargs(
|
||||||
do_resize=do_resize,
|
|
||||||
size=size,
|
|
||||||
resample=resample,
|
|
||||||
do_rescale=do_rescale,
|
do_rescale=do_rescale,
|
||||||
rescale_factor=rescale_factor,
|
rescale_factor=rescale_factor,
|
||||||
do_normalize=do_normalize,
|
do_normalize=do_normalize,
|
||||||
image_mean=image_mean,
|
image_mean=image_mean,
|
||||||
image_std=image_std,
|
image_std=image_std,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
|
interpolation = (
|
||||||
|
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||||
|
)
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
images = make_flat_list_of_images(images)
|
images = make_flat_list_of_images(images)
|
||||||
if videos is not None:
|
if videos is not None:
|
||||||
|
|||||||
@@ -486,15 +486,8 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
|
|||||||
target_size=resized_image.size()[-2:],
|
target_size=resized_image.size()[-2:],
|
||||||
)
|
)
|
||||||
image = resized_image
|
image = resized_image
|
||||||
|
# Fused rescale and normalize
|
||||||
if do_rescale and do_normalize:
|
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||||
# fused rescale and normalize
|
|
||||||
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
|
|
||||||
elif do_rescale:
|
|
||||||
image = image * rescale_factor
|
|
||||||
elif do_normalize:
|
|
||||||
image = F.normalize(image, image_mean, image_std)
|
|
||||||
|
|
||||||
if do_convert_annotations and annotations is not None:
|
if do_convert_annotations and annotations is not None:
|
||||||
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
||||||
|
|
||||||
|
|||||||
@@ -266,15 +266,8 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
|
|||||||
target_size=resized_image.size()[-2:],
|
target_size=resized_image.size()[-2:],
|
||||||
)
|
)
|
||||||
image = resized_image
|
image = resized_image
|
||||||
|
# Fused rescale and normalize
|
||||||
if do_rescale and do_normalize:
|
image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std)
|
||||||
# fused rescale and normalize
|
|
||||||
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
|
|
||||||
elif do_rescale:
|
|
||||||
image = image * rescale_factor
|
|
||||||
elif do_normalize:
|
|
||||||
image = F.normalize(image, image_mean, image_std)
|
|
||||||
|
|
||||||
if do_convert_annotations and annotations is not None:
|
if do_convert_annotations and annotations is not None:
|
||||||
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast Image processor class for SigLIP2."""
|
"""Fast Image processor class for SigLIP2."""
|
||||||
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -117,11 +116,10 @@ class Siglip2ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
|
def __init__(self, **kwargs: Unpack[Siglip2FastImageProcessorKwargs]):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@lru_cache(maxsize=10)
|
def _validate_preprocess_kwargs(self, **kwargs) -> tuple:
|
||||||
def _prepare_process_arguments(self, **kwargs) -> tuple:
|
|
||||||
# Remove do_resize from kwargs to not raise an error as size is None
|
# Remove do_resize from kwargs to not raise an error as size is None
|
||||||
kwargs.pop("do_resize", None)
|
kwargs.pop("do_resize", None)
|
||||||
return super()._prepare_process_arguments(**kwargs)
|
return super()._validate_preprocess_kwargs(**kwargs)
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
|
||||||
|
|||||||
Reference in New Issue
Block a user