From ae1cffaf3cd42d0ab1d7529e3b3118725bca0bcf Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 29 Nov 2022 10:38:01 +0000 Subject: [PATCH] Add Donut image processor (#20425) * Add Donut image processor * Update src/transformers/image_transforms.py Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com> * Fix docstrings * Full var names in docstring Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com> --- docs/source/en/model_doc/donut.mdx | 5 + src/transformers/__init__.py | 4 +- src/transformers/image_transforms.py | 12 +- .../models/auto/image_processing_auto.py | 1 + src/transformers/models/donut/__init__.py | 2 + .../models/donut/feature_extraction_donut.py | 192 +------- .../models/donut/image_processing_donut.py | 426 ++++++++++++++++++ .../models/donut/processing_donut.py | 36 +- .../utils/dummy_vision_objects.py | 7 + .../donut/test_feature_extraction_donut.py | 28 +- 10 files changed, 498 insertions(+), 215 deletions(-) create mode 100644 src/transformers/models/donut/image_processing_donut.py diff --git a/docs/source/en/model_doc/donut.mdx b/docs/source/en/model_doc/donut.mdx index 13bfcfd65b..62ce32fd9c 100644 --- a/docs/source/en/model_doc/donut.mdx +++ b/docs/source/en/model_doc/donut.mdx @@ -194,6 +194,11 @@ We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers- [[autodoc]] DonutSwinConfig +## DonutImageProcessor + +[[autodoc]] DonutImageProcessor + - preprocess + ## DonutFeatureExtractor [[autodoc]] DonutFeatureExtractor diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 714ce483ea..107ee91861 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -727,7 +727,7 @@ else: _import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"]) _import_structure["models.detr"].append("DetrFeatureExtractor") _import_structure["models.conditional_detr"].append("ConditionalDetrFeatureExtractor") - _import_structure["models.donut"].append("DonutFeatureExtractor") + _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"]) _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"]) _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor", "FlavaImageProcessor"]) _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"]) @@ -3853,7 +3853,7 @@ if TYPE_CHECKING: from .models.deformable_detr import DeformableDetrFeatureExtractor from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor from .models.detr import DetrFeatureExtractor - from .models.donut import DonutFeatureExtractor + from .models.donut import DonutFeatureExtractor, DonutImageProcessor from .models.dpt import DPTFeatureExtractor, DPTImageProcessor from .models.flava import FlavaFeatureExtractor, FlavaImageProcessor, FlavaProcessor from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index a0a52d6fd4..d9e2e87458 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -223,11 +223,12 @@ def resize( image, size: Tuple[int, int], resample=PILImageResampling.BILINEAR, + reducing_gap: Optional[int] = None, data_format: Optional[ChannelDimension] = None, return_numpy: bool = True, ) -> np.ndarray: """ - Resizes `image` to (h, w) specified by `size` using the PIL library. + Resizes `image` to `(height, width)` specified by `size` using the PIL library. Args: image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): @@ -236,8 +237,11 @@ def resize( The size to use for resizing the image. resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`): The filter to user for resampling. + reducing_gap (`int`, *optional*): + Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to + the fair resampling. See corresponding Pillow documentation for more details. data_format (`ChannelDimension`, *optional*): - The channel dimension format of the output image. If `None`, will use the inferred format from the input. + The channel dimension format of the output image. If unset, will use the inferred format from the input. return_numpy (`bool`, *optional*, defaults to `True`): Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is returned. @@ -260,7 +264,7 @@ def resize( image = to_pil_image(image) height, width = size # PIL images are in the format (width, height) - resized_image = image.resize((width, height), resample=resample) + resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap) if return_numpy: resized_image = np.array(resized_image) @@ -290,7 +294,7 @@ def normalize( std (`float` or `Iterable[float]`): The standard deviation to use for normalization. data_format (`ChannelDimension`, *optional*): - The channel dimension format of the output image. If `None`, will use the inferred format from the input. + The channel dimension format of the output image. If unset, will use the inferred format from the input. """ if isinstance(image, PIL.Image.Image): warnings.warn( diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 785aaac0d0..df77c52314 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -44,6 +44,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( ("data2vec-vision", "BeitImageProcessor"), ("deit", "DeiTImageProcessor"), ("dinat", "ViTImageProcessor"), + ("donut-swin", "DonutImageProcessor"), ("dpt", "DPTImageProcessor"), ("flava", "FlavaImageProcessor"), ("glpn", "GLPNImageProcessor"), diff --git a/src/transformers/models/donut/__init__.py b/src/transformers/models/donut/__init__.py index a01f6b11a9..ee003aa4ac 100644 --- a/src/transformers/models/donut/__init__.py +++ b/src/transformers/models/donut/__init__.py @@ -44,6 +44,7 @@ except OptionalDependencyNotAvailable: pass else: _import_structure["feature_extraction_donut"] = ["DonutFeatureExtractor"] + _import_structure["image_processing_donut"] = ["DonutImageProcessor"] if TYPE_CHECKING: @@ -69,6 +70,7 @@ if TYPE_CHECKING: pass else: from .feature_extraction_donut import DonutFeatureExtractor + from .image_processing_donut import DonutImageProcessor else: import sys diff --git a/src/transformers/models/donut/feature_extraction_donut.py b/src/transformers/models/donut/feature_extraction_donut.py index 4bbc74193c..d29c533ea8 100644 --- a/src/transformers/models/donut/feature_extraction_donut.py +++ b/src/transformers/models/donut/feature_extraction_donut.py @@ -14,197 +14,11 @@ # limitations under the License. """Feature extractor class for Donut.""" -from typing import Optional, Tuple, Union - -import numpy as np -from PIL import Image, ImageOps - -from transformers.image_utils import PILImageResampling - -from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin -from ...image_utils import ( - IMAGENET_STANDARD_MEAN, - IMAGENET_STANDARD_STD, - ImageFeatureExtractionMixin, - ImageInput, - is_torch_tensor, -) -from ...utils import TensorType, logging +from ...utils import logging +from .image_processing_donut import DonutImageProcessor logger = logging.get_logger(__name__) -class DonutFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): - r""" - Constructs a Donut feature extractor. - - This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users - should refer to this superclass for more information regarding those methods. - - Args: - do_resize (`bool`, *optional*, defaults to `True`): - Whether to resize the shorter edge of the input to the minimum value of a certain `size`. - size (`Tuple(int)`, *optional*, defaults to [1920, 2560]): - Resize the shorter edge of the input to the minimum value of the given size. Should be a tuple of (width, - height). Only has an effect if `do_resize` is set to `True`. - resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`): - An optional resampling filter. This can be one of `PILImageResampling.NEAREST`, `PILImageResampling.BOX`, - `PILImageResampling.BILINEAR`, `PILImageResampling.HAMMING`, `PILImageResampling.BICUBIC` or - `PILImageResampling.LANCZOS`. Only has an effect if `do_resize` is set to `True`. - do_thumbnail (`bool`, *optional*, defaults to `True`): - Whether to thumbnail the input to the given `size`. - do_align_long_axis (`bool`, *optional*, defaults to `False`): - Whether to rotate the input if the height is greater than width. - do_pad (`bool`, *optional*, defaults to `True`): - Whether or not to pad the input to `size`. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether or not to normalize the input with mean and standard deviation. - image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`): - The sequence of means for each channel, to be used when normalizing images. - image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`): - The sequence of standard deviations for each channel, to be used when normalizing images. - - """ - - model_input_names = ["pixel_values"] - - def __init__( - self, - do_resize=True, - size=[1920, 2560], - resample=PILImageResampling.BILINEAR, - do_thumbnail=True, - do_align_long_axis=False, - do_pad=True, - do_normalize=True, - image_mean=None, - image_std=None, - **kwargs - ): - super().__init__(**kwargs) - self.do_resize = do_resize - self.size = size - self.resample = resample - self.do_thumbnail = do_thumbnail - self.do_align_long_axis = do_align_long_axis - self.do_pad = do_pad - self.do_normalize = do_normalize - self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN - self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD - - def rotate_image(self, image, size): - if not isinstance(image, Image.Image): - image = self.to_pil_image(image) - - if (size[1] > size[0] and image.width > image.height) or (size[1] < size[0] and image.width < image.height): - image = self.rotate(image, angle=-90, expand=True) - - return image - - def thumbnail(self, image, size): - if not isinstance(image, Image.Image): - image = self.to_pil_image(image) - - image.thumbnail((size[0], size[1])) - - return image - - def pad(self, image: Image.Image, size: Tuple[int, int], random_padding: bool = False) -> Image.Image: - delta_width = size[0] - image.width - delta_height = size[1] - image.height - - if random_padding: - pad_width = np.random.randint(low=0, high=delta_width + 1) - pad_height = np.random.randint(low=0, high=delta_height + 1) - else: - pad_width = delta_width // 2 - pad_height = delta_height // 2 - - padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height) - return ImageOps.expand(image, padding) - - def __call__( - self, - images: ImageInput, - return_tensors: Optional[Union[str, TensorType]] = None, - random_padding=False, - **kwargs - ) -> BatchFeature: - """ - Main method to prepare for the model one or several image(s). - - - - NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass - PIL images. - - - - Args: - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a - number of channels, H and W are image height and width. - - random_padding (`bool`, *optional*, defaults to `False`): - Whether to randomly pad the input to `size`. - - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - - Returns: - [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, - width). - """ - # Input type checking for clearer error - valid_images = False - - # Check that images has a valid type - if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): - valid_images = True - elif isinstance(images, (list, tuple)): - if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): - valid_images = True - - if not valid_images: - raise ValueError( - "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " - "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." - ) - - is_batched = bool( - isinstance(images, (list, tuple)) - and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) - ) - - if not is_batched: - images = [images] - - # transformations (rotating + resizing + thumbnailing + padding + normalization) - if self.do_align_long_axis: - images = [self.rotate_image(image, self.size) for image in images] - if self.do_resize and self.size is not None: - images = [ - self.resize(image=image, size=min(self.size), resample=self.resample, default_to_square=False) - for image in images - ] - if self.do_thumbnail and self.size is not None: - images = [self.thumbnail(image=image, size=self.size) for image in images] - if self.do_pad and self.size is not None: - images = [self.pad(image=image, size=self.size, random_padding=random_padding) for image in images] - if self.do_normalize: - images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] - - # return as BatchFeature - data = {"pixel_values": images} - encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) - - return encoded_inputs +DonutFeatureExtractor = DonutImageProcessor diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py new file mode 100644 index 0000000000..a6434bccde --- /dev/null +++ b/src/transformers/models/donut/image_processing_donut.py @@ -0,0 +1,426 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Donut.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + normalize, + pad, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + is_batched, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, logging +from ...utils.import_utils import is_vision_available + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class DonutImageProcessor(BaseImageProcessor): + r""" + Constructs a Donut image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize: + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Image standard deviation. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_thumbnail: bool = True, + do_align_long_axis: bool = False, + do_pad: bool = True, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs + ) -> None: + super().__init__(**kwargs) + + size = size if size is not None else {"height": 2560, "width": 1920} + if isinstance(size, (tuple, list)): + # The previous feature extractor size parameter was in (width, height) format + size = size[::-1] + size = get_size_dict(size) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_thumbnail = do_thumbnail + self.do_align_long_axis = do_align_long_axis + self.do_pad = do_pad + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def align_long_axis( + self, image: np.ndarray, size: Dict[str, int], data_format: Optional[Union[str, ChannelDimension]] = None + ) -> np.ndarray: + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`np.ndarray`): + The image to be aligned. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + + Returns: + `np.ndarray`: The aligned image. + """ + input_height, input_width = get_image_size(image) + output_height, output_width = size["height"], size["width"] + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = np.rot90(image, 3) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format) + + return image + + def rotate_image(self, *args, **kwargs): + logger.info( + "rotate_image is deprecated and will be removed in version 4.27. Please use align_long_axis instead." + ) + return self.align_long_axis(*args, **kwargs) + + def pad_image( + self, + image: np.ndarray, + size: Dict[str, int], + random_padding: bool = False, + data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pad the image to the specified size. + + Args: + image (`np.ndarray`): + The image to be padded. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + random_padding (`bool`, *optional*, defaults to `False`): + Whether to use random padding or not. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + """ + output_height, output_width = size["height"], size["width"] + input_height, input_width = get_image_size(image) + + delta_width = output_width - input_width + delta_height = output_height - input_height + + if random_padding: + pad_top = np.random.randint(low=0, high=delta_height + 1) + pad_left = np.random.randint(low=0, high=delta_width + 1) + else: + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = ((pad_top, pad_bottom), (pad_left, pad_right)) + return pad(image, padding, data_format=data_format) + + def pad(self, *args, **kwargs): + logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.") + return self.pad_image(*args, **kwargs) + + def thumbnail( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs + ) -> np.ndarray: + """ + Resize the image to the specified size using thumbnail method. + + Args: + image (`np.ndarray`): + The image to be resized. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + """ + output_size = (size["height"], size["width"]) + return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs) + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + """ + size = get_size_dict(size) + shortest_edge = min(size["height"], size["width"]) + output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False) + return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) + + def rescale( + self, + image: np.ndarray, + scale: Union[int, float], + data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs + ): + """ + Rescale an image by a scale factor. image = image * scale. + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`int` or `float`): + Scale to apply to the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + """ + return rescale(image, scale=scale, data_format=data_format, **kwargs) + + def normalize( + self, + image: np.ndarray, + mean: Union[float, List[float]], + std: Union[float, List[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + Args: + image (`np.ndarray`): + Image to normalize. + image_mean (`float` or `List[float]`): + Image mean. + image_std (`float` or `List[float]`): + Image standard deviation. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + """ + return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + random_padding: bool = False, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + **kwargs + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to min(size["height"], + size["width"]) with the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random + amont of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + random_padding (`bool`, *optional*, defaults to `self.random_padding`): + Whether to use random padding when padding the image. If `True`, each image in the batch with be padded + with a random amount of padding on each side up to the size of the largest image in the batch. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + if isinstance(size, (tuple, list)): + # Previous feature extractor had size in (width, height) format + size = size[::-1] + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail + do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis + do_pad = do_pad if do_pad is not None else self.do_pad + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + if not is_batched(images): + images = [images] + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None: + raise ValueError("Size must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_pad and size is None: + raise ValueError("Size must be specified if do_pad is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_align_long_axis: + images = [self.align_long_axis(image) for image in images] + + if do_resize: + images = [self.resize(image=image, size=size, resample=resample) for image in images] + + if do_thumbnail: + images = [self.thumbnail(image=image, size=size) for image in images] + + if do_pad: + images = [self.pad(image=image, size=size, random_padding=random_padding) for image in images] + + if do_rescale: + images = [self.rescale(image=image, scale=rescale_factor) for image in images] + + if do_normalize: + images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images] + + images = [to_channel_dimension_format(image, data_format) for image in images] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py index da9e89c1d8..d45a91203c 100644 --- a/src/transformers/models/donut/processing_donut.py +++ b/src/transformers/models/donut/processing_donut.py @@ -37,12 +37,27 @@ class DonutProcessor(ProcessorMixin): tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]): An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input. """ - feature_extractor_class = "AutoFeatureExtractor" + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, feature_extractor, tokenizer): - super().__init__(feature_extractor, tokenizer) - self.current_processor = self.feature_extractor + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v4.27, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor self._in_target_context_manager = False def __call__(self, *args, **kwargs): @@ -66,7 +81,7 @@ class DonutProcessor(ProcessorMixin): raise ValueError("You need to specify either an `images` or `text` input to process.") if images is not None: - inputs = self.feature_extractor(images, *args, **kwargs) + inputs = self.image_processor(images, *args, **kwargs) if text is not None: encodings = self.tokenizer(text, **kwargs) @@ -105,7 +120,7 @@ class DonutProcessor(ProcessorMixin): self._in_target_context_manager = True self.current_processor = self.tokenizer yield - self.current_processor = self.feature_extractor + self.current_processor = self.image_processor self._in_target_context_manager = False def token2json(self, tokens, is_inner_value=False, added_vocab=None): @@ -157,3 +172,12 @@ class DonutProcessor(ProcessorMixin): return [output] if is_inner_value else output else: return [] if is_inner_value else {"text_sequence": tokens} + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`" + " instead.", + FutureWarning, + ) + return self.image_processor_class diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 7ce1f18670..fb35aad69a 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -113,6 +113,13 @@ class DonutFeatureExtractor(metaclass=DummyObject): requires_backends(self, ["vision"]) +class DonutImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class DPTFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/donut/test_feature_extraction_donut.py b/tests/models/donut/test_feature_extraction_donut.py index 38ccbf2075..1a8a8d0936 100644 --- a/tests/models/donut/test_feature_extraction_donut.py +++ b/tests/models/donut/test_feature_extraction_donut.py @@ -43,7 +43,7 @@ class DonutFeatureExtractionTester(unittest.TestCase): min_resolution=30, max_resolution=400, do_resize=True, - size=[20, 18], + size=None, do_thumbnail=True, do_align_axis=False, do_pad=True, @@ -58,7 +58,7 @@ class DonutFeatureExtractionTester(unittest.TestCase): self.min_resolution = min_resolution self.max_resolution = max_resolution self.do_resize = do_resize - self.size = size + self.size = size if size is not None else {"height": 18, "width": 20} self.do_thumbnail = do_thumbnail self.do_align_axis = do_align_axis self.do_pad = do_pad @@ -121,8 +121,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ( 1, self.feature_extract_tester.num_channels, - self.feature_extract_tester.size[1], - self.feature_extract_tester.size[0], + self.feature_extract_tester.size["height"], + self.feature_extract_tester.size["width"], ), ) @@ -133,8 +133,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ( self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, - self.feature_extract_tester.size[1], - self.feature_extract_tester.size[0], + self.feature_extract_tester.size["height"], + self.feature_extract_tester.size["width"], ), ) @@ -153,8 +153,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ( 1, self.feature_extract_tester.num_channels, - self.feature_extract_tester.size[1], - self.feature_extract_tester.size[0], + self.feature_extract_tester.size["height"], + self.feature_extract_tester.size["width"], ), ) @@ -165,8 +165,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ( self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, - self.feature_extract_tester.size[1], - self.feature_extract_tester.size[0], + self.feature_extract_tester.size["height"], + self.feature_extract_tester.size["width"], ), ) @@ -185,8 +185,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ( 1, self.feature_extract_tester.num_channels, - self.feature_extract_tester.size[1], - self.feature_extract_tester.size[0], + self.feature_extract_tester.size["height"], + self.feature_extract_tester.size["width"], ), ) @@ -197,7 +197,7 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test ( self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, - self.feature_extract_tester.size[1], - self.feature_extract_tester.size[0], + self.feature_extract_tester.size["height"], + self.feature_extract_tester.size["width"], ), )