Add optimized PixtralImageProcessorFast (#34836)
* Add optimized PixtralImageProcessorFast * make style * Add dummy_vision_object * Review comments * Format * Fix dummy * Format * np.ceil for math.ceil
This commit is contained in:
@@ -88,6 +88,11 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
|
|||||||
[[autodoc]] PixtralImageProcessor
|
[[autodoc]] PixtralImageProcessor
|
||||||
- preprocess
|
- preprocess
|
||||||
|
|
||||||
|
## PixtralImageProcessorFast
|
||||||
|
|
||||||
|
[[autodoc]] PixtralImageProcessorFast
|
||||||
|
- preprocess
|
||||||
|
|
||||||
## PixtralProcessor
|
## PixtralProcessor
|
||||||
|
|
||||||
[[autodoc]] PixtralProcessor
|
[[autodoc]] PixtralProcessor
|
||||||
|
|||||||
@@ -1260,6 +1260,7 @@ else:
|
|||||||
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
|
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
|
||||||
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
|
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
|
||||||
_import_structure["models.detr"].append("DetrImageProcessorFast")
|
_import_structure["models.detr"].append("DetrImageProcessorFast")
|
||||||
|
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
|
||||||
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
||||||
_import_structure["models.vit"].append("ViTImageProcessorFast")
|
_import_structure["models.vit"].append("ViTImageProcessorFast")
|
||||||
|
|
||||||
@@ -6189,6 +6190,7 @@ if TYPE_CHECKING:
|
|||||||
from .image_processing_utils_fast import BaseImageProcessorFast
|
from .image_processing_utils_fast import BaseImageProcessorFast
|
||||||
from .models.deformable_detr import DeformableDetrImageProcessorFast
|
from .models.deformable_detr import DeformableDetrImageProcessorFast
|
||||||
from .models.detr import DetrImageProcessorFast
|
from .models.detr import DetrImageProcessorFast
|
||||||
|
from .models.pixtral import PixtralImageProcessorFast
|
||||||
from .models.rt_detr import RTDetrImageProcessorFast
|
from .models.rt_detr import RTDetrImageProcessorFast
|
||||||
from .models.vit import ViTImageProcessorFast
|
from .models.vit import ViTImageProcessorFast
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from packaging import version
|
|||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
|
TensorType,
|
||||||
is_jax_tensor,
|
is_jax_tensor,
|
||||||
is_numpy_array,
|
is_numpy_array,
|
||||||
is_tf_tensor,
|
is_tf_tensor,
|
||||||
@@ -447,6 +448,44 @@ def validate_preprocess_arguments(
|
|||||||
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_fast_preprocess_arguments(
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
do_normalize: Optional[bool] = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_pad: Optional[bool] = None,
|
||||||
|
size_divisibility: Optional[int] = None,
|
||||||
|
do_center_crop: Optional[bool] = None,
|
||||||
|
crop_size: Optional[Dict[str, int]] = None,
|
||||||
|
do_resize: Optional[bool] = None,
|
||||||
|
size: Optional[Dict[str, int]] = None,
|
||||||
|
resample: Optional["PILImageResampling"] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
|
||||||
|
Raises `ValueError` if arguments incompatibility is caught.
|
||||||
|
"""
|
||||||
|
validate_preprocess_arguments(
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
)
|
||||||
|
# Extra checks for ImageProcessorFast
|
||||||
|
if return_tensors != "pt":
|
||||||
|
raise ValueError("Only returning PyTorch tensors is currently supported.")
|
||||||
|
|
||||||
|
if data_format != ChannelDimension.FIRST:
|
||||||
|
raise ValueError("Only channel first data format is currently supported.")
|
||||||
|
|
||||||
|
|
||||||
# In the future we can add a TF implementation here when we have TF models.
|
# In the future we can add a TF implementation here when we have TF models.
|
||||||
class ImageFeatureExtractionMixin:
|
class ImageFeatureExtractionMixin:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ else:
|
|||||||
("paligemma", ("SiglipImageProcessor",)),
|
("paligemma", ("SiglipImageProcessor",)),
|
||||||
("perceiver", ("PerceiverImageProcessor",)),
|
("perceiver", ("PerceiverImageProcessor",)),
|
||||||
("pix2struct", ("Pix2StructImageProcessor",)),
|
("pix2struct", ("Pix2StructImageProcessor",)),
|
||||||
("pixtral", ("PixtralImageProcessor",)),
|
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||||
("poolformer", ("PoolFormerImageProcessor",)),
|
("poolformer", ("PoolFormerImageProcessor",)),
|
||||||
("pvt", ("PvtImageProcessor",)),
|
("pvt", ("PvtImageProcessor",)),
|
||||||
("pvt_v2", ("PvtImageProcessor",)),
|
("pvt_v2", ("PvtImageProcessor",)),
|
||||||
|
|||||||
@@ -13,7 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
from ...utils import (
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -41,6 +47,14 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
|
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_torchvision_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_import_structure["image_processing_pixtral_fast"] = ["PixtralImageProcessorFast"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_pixtral import PixtralVisionConfig
|
from .configuration_pixtral import PixtralVisionConfig
|
||||||
@@ -65,6 +79,14 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .image_processing_pixtral import PixtralImageProcessor
|
from .image_processing_pixtral import PixtralImageProcessor
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_torchvision_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
from .image_processing_pixtral_fast import PixtralImageProcessorFast
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for Pixtral."""
|
"""Image processor class for Pixtral."""
|
||||||
|
|
||||||
|
import math
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -179,7 +180,7 @@ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int])
|
|||||||
|
|
||||||
|
|
||||||
def get_resize_output_image_size(
|
def get_resize_output_image_size(
|
||||||
input_image: np.ndarray,
|
input_image: ImageInput,
|
||||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||||
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
@@ -189,7 +190,7 @@ def get_resize_output_image_size(
|
|||||||
size.
|
size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_image (`np.ndarray`):
|
input_image (`ImageInput`):
|
||||||
The image to resize.
|
The image to resize.
|
||||||
size (`int` or `Tuple[int, int]`):
|
size (`int` or `Tuple[int, int]`):
|
||||||
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
|
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
|
||||||
@@ -210,8 +211,8 @@ def get_resize_output_image_size(
|
|||||||
|
|
||||||
if ratio > 1:
|
if ratio > 1:
|
||||||
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
||||||
height = int(np.ceil(height / ratio))
|
height = int(math.ceil(height / ratio))
|
||||||
width = int(np.ceil(width / ratio))
|
width = int(math.ceil(width / ratio))
|
||||||
|
|
||||||
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
||||||
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
||||||
|
|||||||
349
src/transformers/models/pixtral/image_processing_pixtral_fast.py
Normal file
349
src/transformers/models/pixtral/image_processing_pixtral_fast.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image processor class for Pixtral."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from ...image_processing_utils import get_size_dict
|
||||||
|
from ...image_processing_utils_fast import BaseImageProcessorFast
|
||||||
|
from ...image_utils import (
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
ImageType,
|
||||||
|
PILImageResampling,
|
||||||
|
get_image_size,
|
||||||
|
get_image_type,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
validate_fast_preprocess_arguments,
|
||||||
|
validate_kwargs,
|
||||||
|
)
|
||||||
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
is_torchvision_v2_available,
|
||||||
|
is_vision_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from .image_processing_pixtral import (
|
||||||
|
BatchMixFeature,
|
||||||
|
convert_to_rgb,
|
||||||
|
get_resize_output_image_size,
|
||||||
|
make_list_of_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
if is_vision_available():
|
||||||
|
from ...image_utils import pil_torch_interpolation_mapping
|
||||||
|
|
||||||
|
if is_torchvision_v2_available():
|
||||||
|
from torchvision.transforms.v2 import functional as F
|
||||||
|
else:
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralImageProcessorFast(BaseImageProcessorFast):
|
||||||
|
r"""
|
||||||
|
Constructs a fast Pixtral image processor that leverages torchvision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||||
|
`do_resize` in the `preprocess` method.
|
||||||
|
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
|
||||||
|
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
|
||||||
|
images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
|
||||||
|
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
|
||||||
|
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||||
|
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||||
|
the `preprocess` method.
|
||||||
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||||
|
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||||
|
method.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||||
|
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||||
|
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
patch_size: Dict[str, int] = None,
|
||||||
|
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BICUBIC,
|
||||||
|
do_rescale: bool = True,
|
||||||
|
rescale_factor: Union[int, float] = 1 / 255,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_convert_rgb: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
size = size if size is not None else {"longest_edge": 1024}
|
||||||
|
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
|
||||||
|
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||||
|
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.resample = resample
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||||
|
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
self._valid_processor_keys = [
|
||||||
|
"images",
|
||||||
|
"do_resize",
|
||||||
|
"size",
|
||||||
|
"patch_size",
|
||||||
|
"resample",
|
||||||
|
"do_rescale",
|
||||||
|
"rescale_factor",
|
||||||
|
"do_normalize",
|
||||||
|
"image_mean",
|
||||||
|
"image_std",
|
||||||
|
"do_convert_rgb",
|
||||||
|
"return_tensors",
|
||||||
|
"data_format",
|
||||||
|
"input_data_format",
|
||||||
|
]
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
size: Dict[str, int],
|
||||||
|
patch_size: Dict[str, int],
|
||||||
|
interpolation: "F.InterpolationMode" = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||||
|
resized to keep the input aspect ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`torch.Tensor`):
|
||||||
|
Image to resize.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Dict containing the longest possible edge of the image.
|
||||||
|
patch_size (`Dict[str, int]`):
|
||||||
|
Patch size used to calculate the size of the output image.
|
||||||
|
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||||
|
Resampling filter to use when resiizing the image.
|
||||||
|
"""
|
||||||
|
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
|
||||||
|
if "longest_edge" in size:
|
||||||
|
size = (size["longest_edge"], size["longest_edge"])
|
||||||
|
elif "height" in size and "width" in size:
|
||||||
|
size = (size["height"], size["width"])
|
||||||
|
else:
|
||||||
|
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
|
||||||
|
|
||||||
|
if "height" in patch_size and "width" in patch_size:
|
||||||
|
patch_size = (patch_size["height"], patch_size["width"])
|
||||||
|
else:
|
||||||
|
raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
|
||||||
|
|
||||||
|
output_size = get_resize_output_image_size(
|
||||||
|
image,
|
||||||
|
size=size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
)
|
||||||
|
return F.resize(
|
||||||
|
image,
|
||||||
|
size=output_size,
|
||||||
|
interpolation=interpolation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
do_resize: bool = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
patch_size: Dict[str, int] = None,
|
||||||
|
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
|
||||||
|
do_rescale: bool = None,
|
||||||
|
rescale_factor: float = None,
|
||||||
|
do_normalize: bool = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_convert_rgb: bool = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchMixFeature:
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||||
|
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Describes the maximum input dimensions to the model.
|
||||||
|
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||||
|
Patch size in the model. Used to calculate the image after resizing.
|
||||||
|
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
|
||||||
|
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image.
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||||
|
Whether to normalize the image.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||||
|
`True`.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: Use the channel dimension format of the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||||
|
from the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
"""
|
||||||
|
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||||
|
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||||
|
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||||
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||||
|
device = kwargs.pop("device", None)
|
||||||
|
|
||||||
|
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||||
|
|
||||||
|
images_list = make_list_of_images(images)
|
||||||
|
image_type = get_image_type(images_list[0][0])
|
||||||
|
|
||||||
|
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
|
||||||
|
raise ValueError(f"Unsupported input image type {image_type}")
|
||||||
|
|
||||||
|
validate_fast_preprocess_arguments(
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
data_format=data_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_convert_rgb:
|
||||||
|
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
if image_type == ImageType.PIL:
|
||||||
|
images_list = [[F.pil_to_tensor(image) for image in images] for images in images_list]
|
||||||
|
elif image_type == ImageType.NUMPY:
|
||||||
|
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
|
||||||
|
images_list = [[torch.from_numpy(image).contiguous() for image in images] for images in images_list]
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
images_list = [[image.to(device) for image in images] for images in images_list]
|
||||||
|
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||||
|
if input_data_format == ChannelDimension.LAST:
|
||||||
|
images_list = [[image.permute(2, 0, 1).contiguous() for image in images] for images in images_list]
|
||||||
|
input_data_format = ChannelDimension.FIRST
|
||||||
|
|
||||||
|
if do_rescale and do_normalize:
|
||||||
|
# fused rescale and normalize
|
||||||
|
new_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor)
|
||||||
|
new_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor)
|
||||||
|
|
||||||
|
batch_images = []
|
||||||
|
batch_image_sizes = []
|
||||||
|
for sample_images in images_list:
|
||||||
|
images = []
|
||||||
|
image_sizes = []
|
||||||
|
for image in sample_images:
|
||||||
|
if do_resize:
|
||||||
|
interpolation = (
|
||||||
|
pil_torch_interpolation_mapping[resample]
|
||||||
|
if isinstance(resample, (PILImageResampling, int))
|
||||||
|
else resample
|
||||||
|
)
|
||||||
|
image = self.resize(
|
||||||
|
image=image,
|
||||||
|
size=size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
interpolation=interpolation,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_rescale and do_normalize:
|
||||||
|
# fused rescale and normalize
|
||||||
|
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
|
||||||
|
elif do_rescale:
|
||||||
|
image = image * rescale_factor
|
||||||
|
elif do_normalize:
|
||||||
|
image = F.normalize(image, image_mean, image_std)
|
||||||
|
|
||||||
|
images.append(image)
|
||||||
|
image_sizes.append(get_image_size(image, input_data_format))
|
||||||
|
batch_images.append(images)
|
||||||
|
batch_image_sizes.append(image_sizes)
|
||||||
|
|
||||||
|
return BatchMixFeature(data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, tensor_type=None)
|
||||||
@@ -23,6 +23,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torchvision"])
|
requires_backends(self, ["torchvision"])
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralImageProcessorFast(metaclass=DummyObject):
|
||||||
|
_backends = ["torchvision"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torchvision"])
|
||||||
|
|
||||||
|
|
||||||
class RTDetrImageProcessorFast(metaclass=DummyObject):
|
class RTDetrImageProcessorFast(metaclass=DummyObject):
|
||||||
_backends = ["torchvision"]
|
_backends = ["torchvision"]
|
||||||
|
|
||||||
|
|||||||
@@ -14,12 +14,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||||
|
|
||||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||||
|
|
||||||
@@ -32,6 +34,9 @@ if is_vision_available():
|
|||||||
|
|
||||||
from transformers import PixtralImageProcessor
|
from transformers import PixtralImageProcessor
|
||||||
|
|
||||||
|
if is_torchvision_available():
|
||||||
|
from transformers import PixtralImageProcessorFast
|
||||||
|
|
||||||
|
|
||||||
class PixtralImageProcessingTester(unittest.TestCase):
|
class PixtralImageProcessingTester(unittest.TestCase):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -51,6 +56,7 @@ class PixtralImageProcessingTester(unittest.TestCase):
|
|||||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||||
do_convert_rgb=True,
|
do_convert_rgb=True,
|
||||||
):
|
):
|
||||||
|
super().__init__()
|
||||||
size = size if size is not None else {"longest_edge": 24}
|
size = size if size is not None else {"longest_edge": 24}
|
||||||
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
|
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@@ -128,6 +134,7 @@ class PixtralImageProcessingTester(unittest.TestCase):
|
|||||||
@require_vision
|
@require_vision
|
||||||
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||||
image_processing_class = PixtralImageProcessor if is_vision_available() else None
|
image_processing_class = PixtralImageProcessor if is_vision_available() else None
|
||||||
|
fast_image_processing_class = PixtralImageProcessorFast if is_torchvision_available() else None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
@@ -138,7 +145,8 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
return self.image_processor_tester.prepare_image_processor_dict()
|
return self.image_processor_tester.prepare_image_processor_dict()
|
||||||
|
|
||||||
def test_image_processor_properties(self):
|
def test_image_processor_properties(self):
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
for image_processing_class in self.image_processor_list:
|
||||||
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||||
self.assertTrue(hasattr(image_processing, "size"))
|
self.assertTrue(hasattr(image_processing, "size"))
|
||||||
self.assertTrue(hasattr(image_processing, "patch_size"))
|
self.assertTrue(hasattr(image_processing, "patch_size"))
|
||||||
@@ -150,8 +158,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||||
|
|
||||||
def test_call_pil(self):
|
def test_call_pil(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
# create random PIL images
|
# create random PIL images
|
||||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
|
||||||
for image_inputs in image_inputs_list:
|
for image_inputs in image_inputs_list:
|
||||||
@@ -160,7 +169,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Test not batched input
|
# Test not batched input
|
||||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||||
|
image_inputs_list[0][0]
|
||||||
|
)
|
||||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||||
|
|
||||||
# Test batched
|
# Test batched
|
||||||
@@ -171,8 +182,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||||
|
|
||||||
def test_call_numpy(self):
|
def test_call_numpy(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
# create random numpy tensors
|
# create random numpy tensors
|
||||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
|
||||||
for image_inputs in image_inputs_list:
|
for image_inputs in image_inputs_list:
|
||||||
@@ -181,7 +193,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Test not batched input
|
# Test not batched input
|
||||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||||
|
image_inputs_list[0][0]
|
||||||
|
)
|
||||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||||
|
|
||||||
# Test batched
|
# Test batched
|
||||||
@@ -192,8 +206,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||||
|
|
||||||
def test_call_pytorch(self):
|
def test_call_pytorch(self):
|
||||||
|
for image_processing_class in self.image_processor_list:
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = image_processing_class(**self.image_processor_dict)
|
||||||
# create random PyTorch tensors
|
# create random PyTorch tensors
|
||||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||||
for image_inputs in image_inputs_list:
|
for image_inputs in image_inputs_list:
|
||||||
@@ -202,7 +217,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Test not batched input
|
# Test not batched input
|
||||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||||
|
image_inputs_list[0][0]
|
||||||
|
)
|
||||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||||
|
|
||||||
# Test batched
|
# Test batched
|
||||||
@@ -212,6 +229,50 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
def test_fast_is_faster_than_slow(self):
|
||||||
|
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||||
|
self.skipTest(reason="Skipping speed test")
|
||||||
|
|
||||||
|
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||||
|
self.skipTest(reason="Skipping speed test as one of the image processors is not defined")
|
||||||
|
|
||||||
|
def measure_time(image_processor, image):
|
||||||
|
start = time.time()
|
||||||
|
_ = image_processor(image, return_tensors="pt")
|
||||||
|
return time.time() - start
|
||||||
|
|
||||||
|
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||||
|
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
|
fast_time = measure_time(image_processor_fast, image_inputs_list)
|
||||||
|
slow_time = measure_time(image_processor_slow, image_inputs_list)
|
||||||
|
|
||||||
|
self.assertLessEqual(fast_time, slow_time)
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
def test_slow_fast_equivalence(self):
|
||||||
|
dummy_image = Image.open(
|
||||||
|
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||||
|
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||||
|
|
||||||
|
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||||
|
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||||
|
|
||||||
|
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||||
|
|
||||||
|
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||||
|
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-2))
|
||||||
|
|
||||||
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||||
def test_call_numpy_4_channels(self):
|
def test_call_numpy_4_channels(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user