Add args support for fast image processors (#37018)

* add args support to fast image processors

* add comment for clarity

* fix-copies

* Handle child class args passed as both args or kwargs in call and preprocess functions

* revert support args passed as kwargs in overwritten preprocess

* fix image processor errors
This commit is contained in:
Yoni Gozlan
2025-05-16 12:01:46 -04:00
committed by GitHub
parent d69945e5fc
commit 0ba95564b7
12 changed files with 68 additions and 71 deletions

View File

@@ -18,11 +18,7 @@ from typing import Any, Optional, TypedDict, Union
import numpy as np
from .image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from .image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
@@ -233,6 +229,9 @@ class BaseImageProcessorFast(BaseImageProcessor):
else:
setattr(self, key, getattr(self, key, None))
# get valid kwargs names
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
def resize(
self,
image: "torch.Tensor",
@@ -566,12 +565,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
data_format=data_format,
)
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
return self.preprocess(images, *args, **kwargs)
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
for kwarg_name in self._valid_kwargs_names:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
@@ -603,7 +606,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
kwargs.pop("default_to_square")
kwargs.pop("data_format")
return self._preprocess(images=images, **kwargs)
return self._preprocess(images, *args, **kwargs)
def _preprocess(
self,
@@ -651,6 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict