Add Donut image processor (#20425)
* Add Donut image processor * Update src/transformers/image_transforms.py Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com> * Fix docstrings * Full var names in docstring Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
This commit is contained in:
@@ -194,6 +194,11 @@ We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-
|
|||||||
|
|
||||||
[[autodoc]] DonutSwinConfig
|
[[autodoc]] DonutSwinConfig
|
||||||
|
|
||||||
|
## DonutImageProcessor
|
||||||
|
|
||||||
|
[[autodoc]] DonutImageProcessor
|
||||||
|
- preprocess
|
||||||
|
|
||||||
## DonutFeatureExtractor
|
## DonutFeatureExtractor
|
||||||
|
|
||||||
[[autodoc]] DonutFeatureExtractor
|
[[autodoc]] DonutFeatureExtractor
|
||||||
|
|||||||
@@ -727,7 +727,7 @@ else:
|
|||||||
_import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"])
|
_import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"])
|
||||||
_import_structure["models.detr"].append("DetrFeatureExtractor")
|
_import_structure["models.detr"].append("DetrFeatureExtractor")
|
||||||
_import_structure["models.conditional_detr"].append("ConditionalDetrFeatureExtractor")
|
_import_structure["models.conditional_detr"].append("ConditionalDetrFeatureExtractor")
|
||||||
_import_structure["models.donut"].append("DonutFeatureExtractor")
|
_import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
|
||||||
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
|
_import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
|
||||||
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor", "FlavaImageProcessor"])
|
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor", "FlavaImageProcessor"])
|
||||||
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
|
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
|
||||||
@@ -3853,7 +3853,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.deformable_detr import DeformableDetrFeatureExtractor
|
from .models.deformable_detr import DeformableDetrFeatureExtractor
|
||||||
from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor
|
from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor
|
||||||
from .models.detr import DetrFeatureExtractor
|
from .models.detr import DetrFeatureExtractor
|
||||||
from .models.donut import DonutFeatureExtractor
|
from .models.donut import DonutFeatureExtractor, DonutImageProcessor
|
||||||
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
|
from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
|
||||||
from .models.flava import FlavaFeatureExtractor, FlavaImageProcessor, FlavaProcessor
|
from .models.flava import FlavaFeatureExtractor, FlavaImageProcessor, FlavaProcessor
|
||||||
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
|
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
|
||||||
|
|||||||
@@ -223,11 +223,12 @@ def resize(
|
|||||||
image,
|
image,
|
||||||
size: Tuple[int, int],
|
size: Tuple[int, int],
|
||||||
resample=PILImageResampling.BILINEAR,
|
resample=PILImageResampling.BILINEAR,
|
||||||
|
reducing_gap: Optional[int] = None,
|
||||||
data_format: Optional[ChannelDimension] = None,
|
data_format: Optional[ChannelDimension] = None,
|
||||||
return_numpy: bool = True,
|
return_numpy: bool = True,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Resizes `image` to (h, w) specified by `size` using the PIL library.
|
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
||||||
@@ -236,8 +237,11 @@ def resize(
|
|||||||
The size to use for resizing the image.
|
The size to use for resizing the image.
|
||||||
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||||
The filter to user for resampling.
|
The filter to user for resampling.
|
||||||
|
reducing_gap (`int`, *optional*):
|
||||||
|
Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
|
||||||
|
the fair resampling. See corresponding Pillow documentation for more details.
|
||||||
data_format (`ChannelDimension`, *optional*):
|
data_format (`ChannelDimension`, *optional*):
|
||||||
The channel dimension format of the output image. If `None`, will use the inferred format from the input.
|
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
||||||
return_numpy (`bool`, *optional*, defaults to `True`):
|
return_numpy (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
|
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
|
||||||
returned.
|
returned.
|
||||||
@@ -260,7 +264,7 @@ def resize(
|
|||||||
image = to_pil_image(image)
|
image = to_pil_image(image)
|
||||||
height, width = size
|
height, width = size
|
||||||
# PIL images are in the format (width, height)
|
# PIL images are in the format (width, height)
|
||||||
resized_image = image.resize((width, height), resample=resample)
|
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
|
||||||
|
|
||||||
if return_numpy:
|
if return_numpy:
|
||||||
resized_image = np.array(resized_image)
|
resized_image = np.array(resized_image)
|
||||||
@@ -290,7 +294,7 @@ def normalize(
|
|||||||
std (`float` or `Iterable[float]`):
|
std (`float` or `Iterable[float]`):
|
||||||
The standard deviation to use for normalization.
|
The standard deviation to use for normalization.
|
||||||
data_format (`ChannelDimension`, *optional*):
|
data_format (`ChannelDimension`, *optional*):
|
||||||
The channel dimension format of the output image. If `None`, will use the inferred format from the input.
|
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
||||||
"""
|
"""
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("data2vec-vision", "BeitImageProcessor"),
|
("data2vec-vision", "BeitImageProcessor"),
|
||||||
("deit", "DeiTImageProcessor"),
|
("deit", "DeiTImageProcessor"),
|
||||||
("dinat", "ViTImageProcessor"),
|
("dinat", "ViTImageProcessor"),
|
||||||
|
("donut-swin", "DonutImageProcessor"),
|
||||||
("dpt", "DPTImageProcessor"),
|
("dpt", "DPTImageProcessor"),
|
||||||
("flava", "FlavaImageProcessor"),
|
("flava", "FlavaImageProcessor"),
|
||||||
("glpn", "GLPNImageProcessor"),
|
("glpn", "GLPNImageProcessor"),
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
_import_structure["feature_extraction_donut"] = ["DonutFeatureExtractor"]
|
_import_structure["feature_extraction_donut"] = ["DonutFeatureExtractor"]
|
||||||
|
_import_structure["image_processing_donut"] = ["DonutImageProcessor"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
from .feature_extraction_donut import DonutFeatureExtractor
|
from .feature_extraction_donut import DonutFeatureExtractor
|
||||||
|
from .image_processing_donut import DonutImageProcessor
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -14,197 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Feature extractor class for Donut."""
|
"""Feature extractor class for Donut."""
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from ...utils import logging
|
||||||
|
from .image_processing_donut import DonutImageProcessor
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageOps
|
|
||||||
|
|
||||||
from transformers.image_utils import PILImageResampling
|
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
|
||||||
from ...image_utils import (
|
|
||||||
IMAGENET_STANDARD_MEAN,
|
|
||||||
IMAGENET_STANDARD_STD,
|
|
||||||
ImageFeatureExtractionMixin,
|
|
||||||
ImageInput,
|
|
||||||
is_torch_tensor,
|
|
||||||
)
|
|
||||||
from ...utils import TensorType, logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DonutFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
DonutFeatureExtractor = DonutImageProcessor
|
||||||
r"""
|
|
||||||
Constructs a Donut feature extractor.
|
|
||||||
|
|
||||||
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
|
|
||||||
should refer to this superclass for more information regarding those methods.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
do_resize (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether to resize the shorter edge of the input to the minimum value of a certain `size`.
|
|
||||||
size (`Tuple(int)`, *optional*, defaults to [1920, 2560]):
|
|
||||||
Resize the shorter edge of the input to the minimum value of the given size. Should be a tuple of (width,
|
|
||||||
height). Only has an effect if `do_resize` is set to `True`.
|
|
||||||
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
|
||||||
An optional resampling filter. This can be one of `PILImageResampling.NEAREST`, `PILImageResampling.BOX`,
|
|
||||||
`PILImageResampling.BILINEAR`, `PILImageResampling.HAMMING`, `PILImageResampling.BICUBIC` or
|
|
||||||
`PILImageResampling.LANCZOS`. Only has an effect if `do_resize` is set to `True`.
|
|
||||||
do_thumbnail (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether to thumbnail the input to the given `size`.
|
|
||||||
do_align_long_axis (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to rotate the input if the height is greater than width.
|
|
||||||
do_pad (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not to pad the input to `size`.
|
|
||||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not to normalize the input with mean and standard deviation.
|
|
||||||
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
|
|
||||||
The sequence of means for each channel, to be used when normalizing images.
|
|
||||||
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
|
|
||||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_input_names = ["pixel_values"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
do_resize=True,
|
|
||||||
size=[1920, 2560],
|
|
||||||
resample=PILImageResampling.BILINEAR,
|
|
||||||
do_thumbnail=True,
|
|
||||||
do_align_long_axis=False,
|
|
||||||
do_pad=True,
|
|
||||||
do_normalize=True,
|
|
||||||
image_mean=None,
|
|
||||||
image_std=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.do_resize = do_resize
|
|
||||||
self.size = size
|
|
||||||
self.resample = resample
|
|
||||||
self.do_thumbnail = do_thumbnail
|
|
||||||
self.do_align_long_axis = do_align_long_axis
|
|
||||||
self.do_pad = do_pad
|
|
||||||
self.do_normalize = do_normalize
|
|
||||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
|
||||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
|
||||||
|
|
||||||
def rotate_image(self, image, size):
|
|
||||||
if not isinstance(image, Image.Image):
|
|
||||||
image = self.to_pil_image(image)
|
|
||||||
|
|
||||||
if (size[1] > size[0] and image.width > image.height) or (size[1] < size[0] and image.width < image.height):
|
|
||||||
image = self.rotate(image, angle=-90, expand=True)
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def thumbnail(self, image, size):
|
|
||||||
if not isinstance(image, Image.Image):
|
|
||||||
image = self.to_pil_image(image)
|
|
||||||
|
|
||||||
image.thumbnail((size[0], size[1]))
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def pad(self, image: Image.Image, size: Tuple[int, int], random_padding: bool = False) -> Image.Image:
|
|
||||||
delta_width = size[0] - image.width
|
|
||||||
delta_height = size[1] - image.height
|
|
||||||
|
|
||||||
if random_padding:
|
|
||||||
pad_width = np.random.randint(low=0, high=delta_width + 1)
|
|
||||||
pad_height = np.random.randint(low=0, high=delta_height + 1)
|
|
||||||
else:
|
|
||||||
pad_width = delta_width // 2
|
|
||||||
pad_height = delta_height // 2
|
|
||||||
|
|
||||||
padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
|
|
||||||
return ImageOps.expand(image, padding)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
images: ImageInput,
|
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
|
||||||
random_padding=False,
|
|
||||||
**kwargs
|
|
||||||
) -> BatchFeature:
|
|
||||||
"""
|
|
||||||
Main method to prepare for the model one or several image(s).
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
|
|
||||||
PIL images.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
|
||||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
|
||||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
|
||||||
number of channels, H and W are image height and width.
|
|
||||||
|
|
||||||
random_padding (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to randomly pad the input to `size`.
|
|
||||||
|
|
||||||
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
|
|
||||||
If set, will return tensors of a particular framework. Acceptable values are:
|
|
||||||
|
|
||||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
|
||||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
|
||||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
|
||||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
|
||||||
|
|
||||||
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
|
|
||||||
width).
|
|
||||||
"""
|
|
||||||
# Input type checking for clearer error
|
|
||||||
valid_images = False
|
|
||||||
|
|
||||||
# Check that images has a valid type
|
|
||||||
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
|
|
||||||
valid_images = True
|
|
||||||
elif isinstance(images, (list, tuple)):
|
|
||||||
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
|
|
||||||
valid_images = True
|
|
||||||
|
|
||||||
if not valid_images:
|
|
||||||
raise ValueError(
|
|
||||||
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
|
|
||||||
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
|
|
||||||
)
|
|
||||||
|
|
||||||
is_batched = bool(
|
|
||||||
isinstance(images, (list, tuple))
|
|
||||||
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_batched:
|
|
||||||
images = [images]
|
|
||||||
|
|
||||||
# transformations (rotating + resizing + thumbnailing + padding + normalization)
|
|
||||||
if self.do_align_long_axis:
|
|
||||||
images = [self.rotate_image(image, self.size) for image in images]
|
|
||||||
if self.do_resize and self.size is not None:
|
|
||||||
images = [
|
|
||||||
self.resize(image=image, size=min(self.size), resample=self.resample, default_to_square=False)
|
|
||||||
for image in images
|
|
||||||
]
|
|
||||||
if self.do_thumbnail and self.size is not None:
|
|
||||||
images = [self.thumbnail(image=image, size=self.size) for image in images]
|
|
||||||
if self.do_pad and self.size is not None:
|
|
||||||
images = [self.pad(image=image, size=self.size, random_padding=random_padding) for image in images]
|
|
||||||
if self.do_normalize:
|
|
||||||
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
|
|
||||||
|
|
||||||
# return as BatchFeature
|
|
||||||
data = {"pixel_values": images}
|
|
||||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
|
||||||
|
|
||||||
return encoded_inputs
|
|
||||||
|
|||||||
426
src/transformers/models/donut/image_processing_donut.py
Normal file
426
src/transformers/models/donut/image_processing_donut.py
Normal file
@@ -0,0 +1,426 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image processor class for Donut."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
|
from ...image_transforms import (
|
||||||
|
get_resize_output_image_size,
|
||||||
|
normalize,
|
||||||
|
pad,
|
||||||
|
rescale,
|
||||||
|
resize,
|
||||||
|
to_channel_dimension_format,
|
||||||
|
)
|
||||||
|
from ...image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
get_image_size,
|
||||||
|
is_batched,
|
||||||
|
to_numpy_array,
|
||||||
|
valid_images,
|
||||||
|
)
|
||||||
|
from ...utils import TensorType, logging
|
||||||
|
from ...utils.import_utils import is_vision_available
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
|
class DonutImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a Donut image processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||||
|
`do_resize` in the `preprocess` method.
|
||||||
|
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||||
|
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
||||||
|
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
||||||
|
method.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||||
|
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||||
|
do_center_crop (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
|
||||||
|
`preprocess` method.
|
||||||
|
crop_size (`Dict[str, int]` *optional*, defaults to 224):
|
||||||
|
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
|
||||||
|
method.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||||
|
the `preprocess` method.
|
||||||
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||||
|
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||||
|
method.
|
||||||
|
do_normalize:
|
||||||
|
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||||
|
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||||
|
Image standard deviation.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||||
|
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||||
|
do_thumbnail: bool = True,
|
||||||
|
do_align_long_axis: bool = False,
|
||||||
|
do_pad: bool = True,
|
||||||
|
do_rescale: bool = True,
|
||||||
|
rescale_factor: Union[int, float] = 1 / 255,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
size = size if size is not None else {"height": 2560, "width": 1920}
|
||||||
|
if isinstance(size, (tuple, list)):
|
||||||
|
# The previous feature extractor size parameter was in (width, height) format
|
||||||
|
size = size[::-1]
|
||||||
|
size = get_size_dict(size)
|
||||||
|
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size
|
||||||
|
self.resample = resample
|
||||||
|
self.do_thumbnail = do_thumbnail
|
||||||
|
self.do_align_long_axis = do_align_long_axis
|
||||||
|
self.do_pad = do_pad
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||||
|
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||||
|
|
||||||
|
def align_long_axis(
|
||||||
|
self, image: np.ndarray, size: Dict[str, int], data_format: Optional[Union[str, ChannelDimension]] = None
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Align the long axis of the image to the longest axis of the specified size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
The image to be aligned.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
The size `{"height": h, "width": w}` to align the long axis to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray`: The aligned image.
|
||||||
|
"""
|
||||||
|
input_height, input_width = get_image_size(image)
|
||||||
|
output_height, output_width = size["height"], size["width"]
|
||||||
|
|
||||||
|
if (output_width < output_height and input_width > input_height) or (
|
||||||
|
output_width > output_height and input_width < input_height
|
||||||
|
):
|
||||||
|
image = np.rot90(image, 3)
|
||||||
|
|
||||||
|
if data_format is not None:
|
||||||
|
image = to_channel_dimension_format(image, data_format)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def rotate_image(self, *args, **kwargs):
|
||||||
|
logger.info(
|
||||||
|
"rotate_image is deprecated and will be removed in version 4.27. Please use align_long_axis instead."
|
||||||
|
)
|
||||||
|
return self.align_long_axis(*args, **kwargs)
|
||||||
|
|
||||||
|
def pad_image(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
random_padding: bool = False,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Pad the image to the specified size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
The image to be padded.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
The size `{"height": h, "width": w}` to pad the image to.
|
||||||
|
random_padding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use random padding or not.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The data format of the output image. If unset, the same format as the input image is used.
|
||||||
|
"""
|
||||||
|
output_height, output_width = size["height"], size["width"]
|
||||||
|
input_height, input_width = get_image_size(image)
|
||||||
|
|
||||||
|
delta_width = output_width - input_width
|
||||||
|
delta_height = output_height - input_height
|
||||||
|
|
||||||
|
if random_padding:
|
||||||
|
pad_top = np.random.randint(low=0, high=delta_height + 1)
|
||||||
|
pad_left = np.random.randint(low=0, high=delta_width + 1)
|
||||||
|
else:
|
||||||
|
pad_top = delta_height // 2
|
||||||
|
pad_left = delta_width // 2
|
||||||
|
|
||||||
|
pad_bottom = delta_height - pad_top
|
||||||
|
pad_right = delta_width - pad_left
|
||||||
|
|
||||||
|
padding = ((pad_top, pad_bottom), (pad_left, pad_right))
|
||||||
|
return pad(image, padding, data_format=data_format)
|
||||||
|
|
||||||
|
def pad(self, *args, **kwargs):
|
||||||
|
logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
|
||||||
|
return self.pad_image(*args, **kwargs)
|
||||||
|
|
||||||
|
def thumbnail(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Resize the image to the specified size using thumbnail method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
The image to be resized.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
The size `{"height": h, "width": w}` to resize the image to.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||||
|
The resampling filter to use.
|
||||||
|
data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
|
||||||
|
The data format of the output image. If unset, the same format as the input image is used.
|
||||||
|
"""
|
||||||
|
output_size = (size["height"], size["width"])
|
||||||
|
return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||||
|
resized to keep the input aspect ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Size of the output image.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||||
|
Resampling filter to use when resiizing the image.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
"""
|
||||||
|
size = get_size_dict(size)
|
||||||
|
shortest_edge = min(size["height"], size["width"])
|
||||||
|
output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
|
||||||
|
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def rescale(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
scale: Union[int, float],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rescale an image by a scale factor. image = image * scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to rescale.
|
||||||
|
scale (`int` or `float`):
|
||||||
|
Scale to apply to the image.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
"""
|
||||||
|
return rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def normalize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
mean: Union[float, List[float]],
|
||||||
|
std: Union[float, List[float]],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize an image. image = (image - image_mean) / image_std.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to normalize.
|
||||||
|
image_mean (`float` or `List[float]`):
|
||||||
|
Image mean.
|
||||||
|
image_std (`float` or `List[float]`):
|
||||||
|
Image standard deviation.
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
"""
|
||||||
|
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
do_resize: bool = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_thumbnail: bool = None,
|
||||||
|
do_align_long_axis: bool = None,
|
||||||
|
do_pad: bool = None,
|
||||||
|
random_padding: bool = False,
|
||||||
|
do_rescale: bool = None,
|
||||||
|
rescale_factor: float = None,
|
||||||
|
do_normalize: bool = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
**kwargs
|
||||||
|
) -> PIL.Image.Image:
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Size of the image after resizing. Shortest edge of the image is resized to min(size["height"],
|
||||||
|
size["width"]) with the longest edge resized to keep the input aspect ratio.
|
||||||
|
resample (`int`, *optional*, defaults to `self.resample`):
|
||||||
|
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
|
||||||
|
Whether to resize the image using thumbnail method.
|
||||||
|
do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
|
||||||
|
Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
|
||||||
|
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||||
|
Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
|
||||||
|
amont of padding on each size, up to the largest image size in the batch. Otherwise, all images are
|
||||||
|
padded to the largest image size in the batch.
|
||||||
|
random_padding (`bool`, *optional*, defaults to `self.random_padding`):
|
||||||
|
Whether to use random padding when padding the image. If `True`, each image in the batch with be padded
|
||||||
|
with a random amount of padding on each side up to the size of the largest image in the batch.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image pixel values.
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||||
|
Whether to normalize the image.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use for normalization.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use for normalization.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: defaults to the channel dimension format of the input image.
|
||||||
|
"""
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
if isinstance(size, (tuple, list)):
|
||||||
|
# Previous feature extractor had size in (width, height) format
|
||||||
|
size = size[::-1]
|
||||||
|
size = get_size_dict(size)
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail
|
||||||
|
do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis
|
||||||
|
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||||
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
|
||||||
|
if not is_batched(images):
|
||||||
|
images = [images]
|
||||||
|
|
||||||
|
if not valid_images(images):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_resize and size is None:
|
||||||
|
raise ValueError("Size must be specified if do_resize is True.")
|
||||||
|
|
||||||
|
if do_rescale and rescale_factor is None:
|
||||||
|
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
||||||
|
|
||||||
|
if do_pad and size is None:
|
||||||
|
raise ValueError("Size must be specified if do_pad is True.")
|
||||||
|
|
||||||
|
if do_normalize and (image_mean is None or image_std is None):
|
||||||
|
raise ValueError("Image mean and std must be specified if do_normalize is True.")
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images = [to_numpy_array(image) for image in images]
|
||||||
|
|
||||||
|
if do_align_long_axis:
|
||||||
|
images = [self.align_long_axis(image) for image in images]
|
||||||
|
|
||||||
|
if do_resize:
|
||||||
|
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||||
|
|
||||||
|
if do_thumbnail:
|
||||||
|
images = [self.thumbnail(image=image, size=size) for image in images]
|
||||||
|
|
||||||
|
if do_pad:
|
||||||
|
images = [self.pad(image=image, size=size, random_padding=random_padding) for image in images]
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||||
|
|
||||||
|
if do_normalize:
|
||||||
|
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||||
|
|
||||||
|
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||||
|
|
||||||
|
data = {"pixel_values": images}
|
||||||
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
@@ -37,12 +37,27 @@ class DonutProcessor(ProcessorMixin):
|
|||||||
tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]):
|
tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]):
|
||||||
An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
|
An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
|
||||||
"""
|
"""
|
||||||
feature_extractor_class = "AutoFeatureExtractor"
|
attributes = ["image_processor", "tokenizer"]
|
||||||
|
image_processor_class = "AutoImageProcessor"
|
||||||
tokenizer_class = "AutoTokenizer"
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
def __init__(self, feature_extractor, tokenizer):
|
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
||||||
super().__init__(feature_extractor, tokenizer)
|
if "feature_extractor" in kwargs:
|
||||||
self.current_processor = self.feature_extractor
|
warnings.warn(
|
||||||
|
"The `feature_extractor` argument is deprecated and will be removed in v4.27, use `image_processor`"
|
||||||
|
" instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
feature_extractor = kwargs.pop("feature_extractor")
|
||||||
|
|
||||||
|
image_processor = image_processor if image_processor is not None else feature_extractor
|
||||||
|
if image_processor is None:
|
||||||
|
raise ValueError("You need to specify an `image_processor`.")
|
||||||
|
if tokenizer is None:
|
||||||
|
raise ValueError("You need to specify a `tokenizer`.")
|
||||||
|
|
||||||
|
super().__init__(image_processor, tokenizer)
|
||||||
|
self.current_processor = self.image_processor
|
||||||
self._in_target_context_manager = False
|
self._in_target_context_manager = False
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
@@ -66,7 +81,7 @@ class DonutProcessor(ProcessorMixin):
|
|||||||
raise ValueError("You need to specify either an `images` or `text` input to process.")
|
raise ValueError("You need to specify either an `images` or `text` input to process.")
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
inputs = self.feature_extractor(images, *args, **kwargs)
|
inputs = self.image_processor(images, *args, **kwargs)
|
||||||
if text is not None:
|
if text is not None:
|
||||||
encodings = self.tokenizer(text, **kwargs)
|
encodings = self.tokenizer(text, **kwargs)
|
||||||
|
|
||||||
@@ -105,7 +120,7 @@ class DonutProcessor(ProcessorMixin):
|
|||||||
self._in_target_context_manager = True
|
self._in_target_context_manager = True
|
||||||
self.current_processor = self.tokenizer
|
self.current_processor = self.tokenizer
|
||||||
yield
|
yield
|
||||||
self.current_processor = self.feature_extractor
|
self.current_processor = self.image_processor
|
||||||
self._in_target_context_manager = False
|
self._in_target_context_manager = False
|
||||||
|
|
||||||
def token2json(self, tokens, is_inner_value=False, added_vocab=None):
|
def token2json(self, tokens, is_inner_value=False, added_vocab=None):
|
||||||
@@ -157,3 +172,12 @@ class DonutProcessor(ProcessorMixin):
|
|||||||
return [output] if is_inner_value else output
|
return [output] if is_inner_value else output
|
||||||
else:
|
else:
|
||||||
return [] if is_inner_value else {"text_sequence": tokens}
|
return [] if is_inner_value else {"text_sequence": tokens}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def feature_extractor_class(self):
|
||||||
|
warnings.warn(
|
||||||
|
"`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
|
||||||
|
" instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
return self.image_processor_class
|
||||||
|
|||||||
@@ -113,6 +113,13 @@ class DonutFeatureExtractor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["vision"])
|
requires_backends(self, ["vision"])
|
||||||
|
|
||||||
|
|
||||||
|
class DonutImageProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["vision"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["vision"])
|
||||||
|
|
||||||
|
|
||||||
class DPTFeatureExtractor(metaclass=DummyObject):
|
class DPTFeatureExtractor(metaclass=DummyObject):
|
||||||
_backends = ["vision"]
|
_backends = ["vision"]
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class DonutFeatureExtractionTester(unittest.TestCase):
|
|||||||
min_resolution=30,
|
min_resolution=30,
|
||||||
max_resolution=400,
|
max_resolution=400,
|
||||||
do_resize=True,
|
do_resize=True,
|
||||||
size=[20, 18],
|
size=None,
|
||||||
do_thumbnail=True,
|
do_thumbnail=True,
|
||||||
do_align_axis=False,
|
do_align_axis=False,
|
||||||
do_pad=True,
|
do_pad=True,
|
||||||
@@ -58,7 +58,7 @@ class DonutFeatureExtractionTester(unittest.TestCase):
|
|||||||
self.min_resolution = min_resolution
|
self.min_resolution = min_resolution
|
||||||
self.max_resolution = max_resolution
|
self.max_resolution = max_resolution
|
||||||
self.do_resize = do_resize
|
self.do_resize = do_resize
|
||||||
self.size = size
|
self.size = size if size is not None else {"height": 18, "width": 20}
|
||||||
self.do_thumbnail = do_thumbnail
|
self.do_thumbnail = do_thumbnail
|
||||||
self.do_align_axis = do_align_axis
|
self.do_align_axis = do_align_axis
|
||||||
self.do_pad = do_pad
|
self.do_pad = do_pad
|
||||||
@@ -121,8 +121,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
|
|||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.feature_extract_tester.num_channels,
|
self.feature_extract_tester.num_channels,
|
||||||
self.feature_extract_tester.size[1],
|
self.feature_extract_tester.size["height"],
|
||||||
self.feature_extract_tester.size[0],
|
self.feature_extract_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,8 +133,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
|
|||||||
(
|
(
|
||||||
self.feature_extract_tester.batch_size,
|
self.feature_extract_tester.batch_size,
|
||||||
self.feature_extract_tester.num_channels,
|
self.feature_extract_tester.num_channels,
|
||||||
self.feature_extract_tester.size[1],
|
self.feature_extract_tester.size["height"],
|
||||||
self.feature_extract_tester.size[0],
|
self.feature_extract_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -153,8 +153,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
|
|||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.feature_extract_tester.num_channels,
|
self.feature_extract_tester.num_channels,
|
||||||
self.feature_extract_tester.size[1],
|
self.feature_extract_tester.size["height"],
|
||||||
self.feature_extract_tester.size[0],
|
self.feature_extract_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -165,8 +165,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
|
|||||||
(
|
(
|
||||||
self.feature_extract_tester.batch_size,
|
self.feature_extract_tester.batch_size,
|
||||||
self.feature_extract_tester.num_channels,
|
self.feature_extract_tester.num_channels,
|
||||||
self.feature_extract_tester.size[1],
|
self.feature_extract_tester.size["height"],
|
||||||
self.feature_extract_tester.size[0],
|
self.feature_extract_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -185,8 +185,8 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
|
|||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.feature_extract_tester.num_channels,
|
self.feature_extract_tester.num_channels,
|
||||||
self.feature_extract_tester.size[1],
|
self.feature_extract_tester.size["height"],
|
||||||
self.feature_extract_tester.size[0],
|
self.feature_extract_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -197,7 +197,7 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
|
|||||||
(
|
(
|
||||||
self.feature_extract_tester.batch_size,
|
self.feature_extract_tester.batch_size,
|
||||||
self.feature_extract_tester.num_channels,
|
self.feature_extract_tester.num_channels,
|
||||||
self.feature_extract_tester.size[1],
|
self.feature_extract_tester.size["height"],
|
||||||
self.feature_extract_tester.size[0],
|
self.feature_extract_tester.size["width"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user