Add optimized PixtralImageProcessorFast (#34836)
* Add optimized PixtralImageProcessorFast * make style * Add dummy_vision_object * Review comments * Format * Fix dummy * Format * np.ceil for math.ceil
This commit is contained in:
@@ -88,6 +88,11 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
|
||||
[[autodoc]] PixtralImageProcessor
|
||||
- preprocess
|
||||
|
||||
## PixtralImageProcessorFast
|
||||
|
||||
[[autodoc]] PixtralImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## PixtralProcessor
|
||||
|
||||
[[autodoc]] PixtralProcessor
|
||||
|
||||
@@ -1260,6 +1260,7 @@ else:
|
||||
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
|
||||
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
|
||||
_import_structure["models.detr"].append("DetrImageProcessorFast")
|
||||
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
|
||||
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
||||
_import_structure["models.vit"].append("ViTImageProcessorFast")
|
||||
|
||||
@@ -6189,6 +6190,7 @@ if TYPE_CHECKING:
|
||||
from .image_processing_utils_fast import BaseImageProcessorFast
|
||||
from .models.deformable_detr import DeformableDetrImageProcessorFast
|
||||
from .models.detr import DetrImageProcessorFast
|
||||
from .models.pixtral import PixtralImageProcessorFast
|
||||
from .models.rt_detr import RTDetrImageProcessorFast
|
||||
from .models.vit import ViTImageProcessorFast
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from packaging import version
|
||||
|
||||
from .utils import (
|
||||
ExplicitEnum,
|
||||
TensorType,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_tf_tensor,
|
||||
@@ -447,6 +448,44 @@ def validate_preprocess_arguments(
|
||||
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
||||
|
||||
|
||||
def validate_fast_preprocess_arguments(
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_pad: Optional[bool] = None,
|
||||
size_divisibility: Optional[int] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
crop_size: Optional[Dict[str, int]] = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: Optional["PILImageResampling"] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
):
|
||||
"""
|
||||
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
|
||||
Raises `ValueError` if arguments incompatibility is caught.
|
||||
"""
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
# Extra checks for ImageProcessorFast
|
||||
if return_tensors != "pt":
|
||||
raise ValueError("Only returning PyTorch tensors is currently supported.")
|
||||
|
||||
if data_format != ChannelDimension.FIRST:
|
||||
raise ValueError("Only channel first data format is currently supported.")
|
||||
|
||||
|
||||
# In the future we can add a TF implementation here when we have TF models.
|
||||
class ImageFeatureExtractionMixin:
|
||||
"""
|
||||
|
||||
@@ -117,7 +117,7 @@ else:
|
||||
("paligemma", ("SiglipImageProcessor",)),
|
||||
("perceiver", ("PerceiverImageProcessor",)),
|
||||
("pix2struct", ("Pix2StructImageProcessor",)),
|
||||
("pixtral", ("PixtralImageProcessor",)),
|
||||
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||
("poolformer", ("PoolFormerImageProcessor",)),
|
||||
("pvt", ("PvtImageProcessor",)),
|
||||
("pvt_v2", ("PvtImageProcessor",)),
|
||||
|
||||
@@ -13,7 +13,13 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@@ -41,6 +47,14 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
|
||||
|
||||
try:
|
||||
if not is_torchvision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_pixtral_fast"] = ["PixtralImageProcessorFast"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_pixtral import PixtralVisionConfig
|
||||
@@ -65,6 +79,14 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .image_processing_pixtral import PixtralImageProcessor
|
||||
|
||||
try:
|
||||
if not is_torchvision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_pixtral_fast import PixtralImageProcessorFast
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Image processor class for Pixtral."""
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -179,7 +180,7 @@ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int])
|
||||
|
||||
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray,
|
||||
input_image: ImageInput,
|
||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
@@ -189,7 +190,7 @@ def get_resize_output_image_size(
|
||||
size.
|
||||
|
||||
Args:
|
||||
input_image (`np.ndarray`):
|
||||
input_image (`ImageInput`):
|
||||
The image to resize.
|
||||
size (`int` or `Tuple[int, int]`):
|
||||
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
|
||||
@@ -210,8 +211,8 @@ def get_resize_output_image_size(
|
||||
|
||||
if ratio > 1:
|
||||
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
|
||||
height = int(np.ceil(height / ratio))
|
||||
width = int(np.ceil(width / ratio))
|
||||
height = int(math.ceil(height / ratio))
|
||||
width = int(math.ceil(width / ratio))
|
||||
|
||||
num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
|
||||
return num_height_tokens * patch_height, num_width_tokens * patch_width
|
||||
|
||||
349
src/transformers/models/pixtral/image_processing_pixtral_fast.py
Normal file
349
src/transformers/models/pixtral/image_processing_pixtral_fast.py
Normal file
@@ -0,0 +1,349 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for Pixtral."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ...image_processing_utils import get_size_dict
|
||||
from ...image_processing_utils_fast import BaseImageProcessorFast
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
ImageType,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
get_image_type,
|
||||
infer_channel_dimension_format,
|
||||
validate_fast_preprocess_arguments,
|
||||
validate_kwargs,
|
||||
)
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
from .image_processing_pixtral import (
|
||||
BatchMixFeature,
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
make_list_of_images,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_vision_available():
|
||||
from ...image_utils import pil_torch_interpolation_mapping
|
||||
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class PixtralImageProcessorFast(BaseImageProcessorFast):
|
||||
r"""
|
||||
Constructs a fast Pixtral image processor that leverages torchvision.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
|
||||
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
|
||||
images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
|
||||
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
|
||||
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
patch_size: Dict[str, int] = None,
|
||||
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"longest_edge": 1024}
|
||||
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
|
||||
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.patch_size = patch_size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self._valid_processor_keys = [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"patch_size",
|
||||
"resample",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format",
|
||||
]
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
size: Dict[str, int],
|
||||
patch_size: Dict[str, int],
|
||||
interpolation: "F.InterpolationMode" = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||
resized to keep the input aspect ratio.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dict containing the longest possible edge of the image.
|
||||
patch_size (`Dict[str, int]`):
|
||||
Patch size used to calculate the size of the output image.
|
||||
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||
Resampling filter to use when resiizing the image.
|
||||
"""
|
||||
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
|
||||
if "longest_edge" in size:
|
||||
size = (size["longest_edge"], size["longest_edge"])
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
|
||||
|
||||
if "height" in patch_size and "width" in patch_size:
|
||||
patch_size = (patch_size["height"], patch_size["width"])
|
||||
else:
|
||||
raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
|
||||
|
||||
output_size = get_resize_output_image_size(
|
||||
image,
|
||||
size=size,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
return F.resize(
|
||||
image,
|
||||
size=output_size,
|
||||
interpolation=interpolation,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
patch_size: Dict[str, int] = None,
|
||||
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchMixFeature:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Describes the maximum input dimensions to the model.
|
||||
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
|
||||
Patch size in the model. Used to calculate the image after resizing.
|
||||
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_size = get_size_dict(patch_size, default_to_square=True)
|
||||
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
device = kwargs.pop("device", None)
|
||||
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||
|
||||
images_list = make_list_of_images(images)
|
||||
image_type = get_image_type(images_list[0][0])
|
||||
|
||||
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
|
||||
raise ValueError(f"Unsupported input image type {image_type}")
|
||||
|
||||
validate_fast_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
return_tensors=return_tensors,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
|
||||
|
||||
if image_type == ImageType.PIL:
|
||||
images_list = [[F.pil_to_tensor(image) for image in images] for images in images_list]
|
||||
elif image_type == ImageType.NUMPY:
|
||||
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
|
||||
images_list = [[torch.from_numpy(image).contiguous() for image in images] for images in images_list]
|
||||
|
||||
if device is not None:
|
||||
images_list = [[image.to(device) for image in images] for images in images_list]
|
||||
|
||||
# We assume that all images have the same channel dimension format.
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||
if input_data_format == ChannelDimension.LAST:
|
||||
images_list = [[image.permute(2, 0, 1).contiguous() for image in images] for images in images_list]
|
||||
input_data_format = ChannelDimension.FIRST
|
||||
|
||||
if do_rescale and do_normalize:
|
||||
# fused rescale and normalize
|
||||
new_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor)
|
||||
new_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor)
|
||||
|
||||
batch_images = []
|
||||
batch_image_sizes = []
|
||||
for sample_images in images_list:
|
||||
images = []
|
||||
image_sizes = []
|
||||
for image in sample_images:
|
||||
if do_resize:
|
||||
interpolation = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, (PILImageResampling, int))
|
||||
else resample
|
||||
)
|
||||
image = self.resize(
|
||||
image=image,
|
||||
size=size,
|
||||
patch_size=patch_size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
|
||||
if do_rescale and do_normalize:
|
||||
# fused rescale and normalize
|
||||
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
|
||||
elif do_rescale:
|
||||
image = image * rescale_factor
|
||||
elif do_normalize:
|
||||
image = F.normalize(image, image_mean, image_std)
|
||||
|
||||
images.append(image)
|
||||
image_sizes.append(get_image_size(image, input_data_format))
|
||||
batch_images.append(images)
|
||||
batch_image_sizes.append(image_sizes)
|
||||
|
||||
return BatchMixFeature(data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, tensor_type=None)
|
||||
@@ -23,6 +23,13 @@ class DetrImageProcessorFast(metaclass=DummyObject):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class PixtralImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class RTDetrImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
|
||||
@@ -14,12 +14,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@@ -32,6 +34,9 @@ if is_vision_available():
|
||||
|
||||
from transformers import PixtralImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import PixtralImageProcessorFast
|
||||
|
||||
|
||||
class PixtralImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
@@ -51,6 +56,7 @@ class PixtralImageProcessingTester(unittest.TestCase):
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
super().__init__()
|
||||
size = size if size is not None else {"longest_edge": 24}
|
||||
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
|
||||
self.parent = parent
|
||||
@@ -128,6 +134,7 @@ class PixtralImageProcessingTester(unittest.TestCase):
|
||||
@require_vision
|
||||
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = PixtralImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = PixtralImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@@ -138,7 +145,8 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "patch_size"))
|
||||
@@ -150,8 +158,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_call_pil(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
|
||||
for image_inputs in image_inputs_list:
|
||||
@@ -160,7 +169,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
image_inputs_list[0][0]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
@@ -171,8 +182,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
|
||||
for image_inputs in image_inputs_list:
|
||||
@@ -181,7 +193,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
image_inputs_list[0][0]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
@@ -192,8 +206,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||
for image_inputs in image_inputs_list:
|
||||
@@ -202,7 +217,9 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
image_inputs_list[0][0]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
@@ -212,6 +229,50 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_fast_is_faster_than_slow(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping speed test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping speed test as one of the image processors is not defined")
|
||||
|
||||
def measure_time(image_processor, image):
|
||||
start = time.time()
|
||||
_ = image_processor(image, return_tensors="pt")
|
||||
return time.time() - start
|
||||
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
fast_time = measure_time(image_processor_fast, image_inputs_list)
|
||||
slow_time = measure_time(image_processor_slow, image_inputs_list)
|
||||
|
||||
self.assertLessEqual(fast_time, slow_time)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_slow_fast_equivalence(self):
|
||||
dummy_image = Image.open(
|
||||
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
|
||||
)
|
||||
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-2))
|
||||
|
||||
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user