Add args support for fast image processors (#37018)
* add args support to fast image processors * add comment for clarity * fix-copies * Handle child class args passed as both args or kwargs in call and preprocess functions * revert support args passed as kwargs in overwritten preprocess * fix image processor errors
This commit is contained in:
@@ -18,11 +18,7 @@ from typing import Any, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .image_processing_utils import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
get_size_dict,
|
||||
)
|
||||
from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from .image_transforms import (
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
@@ -233,6 +229,9 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
else:
|
||||
setattr(self, key, getattr(self, key, None))
|
||||
|
||||
# get valid kwargs names
|
||||
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
@@ -566,12 +565,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
return self.preprocess(images, *args, **kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
for kwarg_name in self._valid_kwargs_names:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
@@ -603,7 +606,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
return self._preprocess(images=images, **kwargs)
|
||||
return self._preprocess(images, *args, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
@@ -651,6 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
def to_dict(self):
|
||||
encoder_dict = super().to_dict()
|
||||
encoder_dict.pop("_valid_processor_keys", None)
|
||||
encoder_dict.pop("_valid_kwargs_names", None)
|
||||
return encoder_dict
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user