From a6b77598805f4e3c24a47767d503dc6ea20d1381 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 2 Nov 2022 11:57:36 +0000
Subject: [PATCH] Add Image Processors (#19796)
* Add CLIP image processor
* Crop size as dict too
* Update warning
* Actually use logger this time
* Normalize doesn't change dtype of input
* Add perceiver image processor
* Tidy up
* Add DPT image processor
* Add Vilt image processor
* Tidy up
* Add poolformer image processor
* Tidy up
* Add LayoutLM v2 and v3 imsge processors
* Tidy up
* Add Flava image processor
* Tidy up
* Add deit image processor
* Tidy up
* Add ConvNext image processor
* Tidy up
* Add levit image processor
* Add segformer image processor
* Add in post processing
* Fix up
* Add ImageGPT image processor
* Fixup
* Add mobilevit image processor
* Tidy up
* Add postprocessing
* Fixup
* Add VideoMAE image processor
* Tidy up
* Add ImageGPT image processor
* Fixup
* Add ViT image processor
* Tidy up
* Add beit image processor
* Add mobilevit image processor
* Tidy up
* Add postprocessing
* Fixup
* Fix up
* Fix flava and remove tree module
* Fix image classification pipeline failing tests
* Update feature extractor in trainer scripts
* Update pad_if_smaller to accept tuple and int size
* Update for image segmentation pipeline
* Update src/transformers/models/perceiver/image_processing_perceiver.py
Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
* Update src/transformers/image_processing_utils.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Update src/transformers/models/beit/image_processing_beit.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* PR comments - docstrings; remove accidentally added resize; var names
* Update docstrings
* Add exception if size is not in the right format
* Fix exception check
* Fix up
* Use shortest_edge in tuple in script
Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
---
docs/source/en/preprocessing.mdx | 9 +-
docs/source/en/tasks/image_classification.mdx | 7 +-
.../run_image_classification.py | 10 +-
.../run_image_classification_no_trainer.py | 10 +-
examples/pytorch/image-pretraining/run_mae.py | 6 +-
.../run_semantic_segmentation.py | 26 +-
.../run_semantic_segmentation_no_trainer.py | 9 +-
src/transformers/image_processing_utils.py | 67 ++
src/transformers/image_transforms.py | 9 +
src/transformers/image_utils.py | 10 +-
.../models/beit/feature_extraction_beit.py | 254 +------
.../models/beit/image_processing_beit.py | 525 +++++++++++++
.../models/clip/feature_extraction_clip.py | 150 +---
.../models/clip/image_processing_clip.py | 342 +++++++++
.../convnext/feature_extraction_convnext.py | 152 +---
.../convnext/image_processing_convnext.py | 310 ++++++++
..._original_pytorch_checkpoint_to_pytorch.py | 2 +-
.../models/deit/feature_extraction_deit.py | 146 +---
.../models/deit/image_processing_deit.py | 315 ++++++++
.../models/dpt/feature_extraction_dpt.py | 230 +-----
.../models/dpt/image_processing_dpt.py | 384 ++++++++++
.../models/flava/feature_extraction_flava.py | 340 +--------
.../models/flava/image_processing_flava.py | 696 ++++++++++++++++++
.../models/glpn/image_processing_glpn.py | 22 +-
.../imagegpt/feature_extraction_imagegpt.py | 163 +---
.../imagegpt/image_processing_imagegpt.py | 239 ++++++
.../models/imagegpt/modeling_imagegpt.py | 5 +-
.../feature_extraction_layoutlmv2.py | 222 +-----
.../layoutlmv2/image_processing_layoutlmv2.py | 268 +++++++
.../feature_extraction_layoutlmv3.py | 230 +-----
.../layoutlmv3/image_processing_layoutlmv3.py | 371 ++++++++++
.../models/levit/feature_extraction_levit.py | 144 +---
.../models/levit/image_processing_levit.py | 342 +++++++++
.../mobilevit/feature_extraction_mobilevit.py | 184 +----
.../mobilevit/image_processing_mobilevit.py | 364 +++++++++
.../perceiver/feature_extraction_perceiver.py | 174 +----
.../perceiver/image_processing_perceiver.py | 330 +++++++++
.../feature_extraction_poolformer.py | 156 +---
.../poolformer/image_processing_poolformer.py | 382 ++++++++++
.../segformer/feature_extraction_segformer.py | 243 +-----
.../segformer/image_processing_segformer.py | 488 ++++++++++++
.../videomae/feature_extraction_videomae.py | 154 +---
.../videomae/image_processing_videomae.py | 380 ++++++++++
.../models/vilt/feature_extraction_vilt.py | 277 +------
.../models/vilt/image_processing_vilt.py | 487 ++++++++++++
.../models/vit/feature_extraction_vit.py | 135 +---
.../models/vit/image_processing_vit.py | 275 +++++++
.../beit/test_feature_extraction_beit.py | 68 +-
.../clip/test_feature_extraction_clip.py | 38 +-
.../test_feature_extraction_convnext.py | 27 +-
.../deit/test_feature_extraction_deit.py | 31 +-
.../models/dpt/test_feature_extraction_dpt.py | 27 +-
.../flava/test_feature_extraction_flava.py | 41 +-
tests/models/flava/test_processor_flava.py | 3 +-
.../test_feature_extraction_imagegpt.py | 3 +-
.../test_feature_extraction_layoutlmv2.py | 37 +-
.../test_feature_extraction_layoutlmv3.py | 27 +-
.../levit/test_feature_extraction_levit.py | 31 +-
.../test_feature_extraction_mobilevit.py | 30 +-
.../test_feature_extraction_poolformer.py | 31 +-
.../test_feature_extraction_segformer.py | 67 +-
.../test_feature_extraction_videomae.py | 33 +-
.../vilt/test_feature_extraction_vilt.py | 14 +-
.../models/vit/test_feature_extraction_vit.py | 27 +-
tests/utils/test_image_processing_utils.py | 71 ++
65 files changed, 7060 insertions(+), 3590 deletions(-)
create mode 100644 src/transformers/models/beit/image_processing_beit.py
create mode 100644 src/transformers/models/clip/image_processing_clip.py
create mode 100644 src/transformers/models/convnext/image_processing_convnext.py
create mode 100644 src/transformers/models/deit/image_processing_deit.py
create mode 100644 src/transformers/models/dpt/image_processing_dpt.py
create mode 100644 src/transformers/models/flava/image_processing_flava.py
create mode 100644 src/transformers/models/imagegpt/image_processing_imagegpt.py
create mode 100644 src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
create mode 100644 src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py
create mode 100644 src/transformers/models/levit/image_processing_levit.py
create mode 100644 src/transformers/models/mobilevit/image_processing_mobilevit.py
create mode 100644 src/transformers/models/perceiver/image_processing_perceiver.py
create mode 100644 src/transformers/models/poolformer/image_processing_poolformer.py
create mode 100644 src/transformers/models/segformer/image_processing_segformer.py
create mode 100644 src/transformers/models/videomae/image_processing_videomae.py
create mode 100644 src/transformers/models/vilt/image_processing_vilt.py
create mode 100644 src/transformers/models/vit/image_processing_vit.py
create mode 100644 tests/utils/test_image_processing_utils.py
diff --git a/docs/source/en/preprocessing.mdx b/docs/source/en/preprocessing.mdx
index de0eed152d..541885c452 100644
--- a/docs/source/en/preprocessing.mdx
+++ b/docs/source/en/preprocessing.mdx
@@ -361,9 +361,12 @@ For computer vision tasks, it is common to add some type of data augmentation to
>>> from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ColorJitter, ToTensor
>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
->>> _transforms = Compose(
-... [RandomResizedCrop(feature_extractor.size), ColorJitter(brightness=0.5, hue=0.5), ToTensor(), normalize]
+>>> size = (
+... feature_extractor.size["shortest_edge"]
+... if "shortest_edge" in feature_extractor.size
+... else (feature_extractor.size["height"], feature_extractor.size["width"])
... )
+>>> _transforms = Compose([RandomResizedCrop(size), ColorJitter(brightness=0.5, hue=0.5), ToTensor(), normalize])
```
2. The model accepts [`pixel_values`](model_doc/visionencoderdecoder#transformers.VisionEncoderDecoderModel.forward.pixel_values) as its input, which is generated by the feature extractor. Create a function that generates `pixel_values` from the transforms:
@@ -487,4 +490,4 @@ Load a processor with [`AutoProcessor.from_pretrained`]:
>>> prepare_dataset(lj_speech[0])
```
-The processor has now added `input_values` and `labels`, and the sampling rate has also been correctly downsampled to 16kHz. You can pass your processed dataset to the model now!
\ No newline at end of file
+The processor has now added `input_values` and `labels`, and the sampling rate has also been correctly downsampled to 16kHz. You can pass your processed dataset to the model now!
diff --git a/docs/source/en/tasks/image_classification.mdx b/docs/source/en/tasks/image_classification.mdx
index f5606cee6a..d2d86df062 100644
--- a/docs/source/en/tasks/image_classification.mdx
+++ b/docs/source/en/tasks/image_classification.mdx
@@ -83,7 +83,12 @@ Apply several image transformations to the dataset to make the model more robust
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
->>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
+>>> size = (
+... feature_extractor.size["shortest_edge"]
+... if "shortest_edge" in feature_extractor.size
+... else (feature_extractor.size["height"], feature_extractor.size["width"])
+... )
+>>> _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
```
Create a preprocessing function that will apply the transforms and return the `pixel_values` - the inputs to the model - of the image:
diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py
index 66eb6d46e5..99ff422481 100644
--- a/examples/pytorch/image-classification/run_image_classification.py
+++ b/examples/pytorch/image-classification/run_image_classification.py
@@ -291,10 +291,14 @@ def main():
)
# Define torchvision transforms to be applied to each image.
+ if "shortest_edge" in feature_extractor.size:
+ size = feature_extractor.size["shortest_edge"]
+ else:
+ size = (feature_extractor.size["height"], feature_extractor.size["width"])
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_train_transforms = Compose(
[
- RandomResizedCrop(feature_extractor.size),
+ RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
@@ -302,8 +306,8 @@ def main():
)
_val_transforms = Compose(
[
- Resize(feature_extractor.size),
- CenterCrop(feature_extractor.size),
+ Resize(size),
+ CenterCrop(size),
ToTensor(),
normalize,
]
diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
index 32d9d7345d..4ce684d981 100644
--- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py
+++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
@@ -315,10 +315,14 @@ def main():
# Preprocessing the datasets
# Define torchvision transforms to be applied to each image.
+ if "shortest_edge" in feature_extractor.size:
+ size = feature_extractor.size["shortest_edge"]
+ else:
+ size = (feature_extractor.size["height"], feature_extractor.size["width"])
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
train_transforms = Compose(
[
- RandomResizedCrop(feature_extractor.size),
+ RandomResizedCrop(size),
RandomHorizontalFlip(),
ToTensor(),
normalize,
@@ -326,8 +330,8 @@ def main():
)
val_transforms = Compose(
[
- Resize(feature_extractor.size),
- CenterCrop(feature_extractor.size),
+ Resize(size),
+ CenterCrop(size),
ToTensor(),
normalize,
]
diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py
index 46130a87c9..4815c80957 100644
--- a/examples/pytorch/image-pretraining/run_mae.py
+++ b/examples/pytorch/image-pretraining/run_mae.py
@@ -298,10 +298,14 @@ def main():
# transformations as done in original MAE paper
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
+ if "shortest_edge" in feature_extractor.size:
+ size = feature_extractor.size["shortest_edge"]
+ else:
+ size = (feature_extractor.size["height"], feature_extractor.size["width"])
transforms = Compose(
[
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
- RandomResizedCrop(feature_extractor.size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
+ RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
index ba1582415a..f287037e8e 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
@@ -57,12 +57,11 @@ require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/sema
def pad_if_smaller(img, size, fill=0):
- min_size = min(img.size)
- if min_size < size:
- original_width, original_height = img.size
- pad_height = size - original_height if original_height < size else 0
- pad_width = size - original_width if original_width < size else 0
- img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
+ size = (size, size) if isinstance(size, int) else size
+ original_width, original_height = img.size
+ pad_height = size[1] - original_height if original_height < size[1] else 0
+ pad_width = size[0] - original_width if original_width < size[0] else 0
+ img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
return img
@@ -110,12 +109,12 @@ class RandomResize:
class RandomCrop:
def __init__(self, size):
- self.size = size
+ self.size = size if isinstance(size, tuple) else (size, size)
def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
- crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size))
+ crop_params = transforms.RandomCrop.get_params(image, self.size)
image = functional.crop(image, *crop_params)
target = functional.crop(target, *crop_params)
return image, target
@@ -359,7 +358,7 @@ def main():
references=labels,
num_labels=len(id2label),
ignore_index=0,
- reduce_labels=feature_extractor.reduce_labels,
+ reduce_labels=feature_extractor.do_reduce_labels,
)
# add per category metrics as individual key-value pairs
per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
@@ -396,10 +395,15 @@ def main():
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
+ if "shortest_edge" in feature_extractor.size:
+ # We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
+ size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
+ else:
+ size = (feature_extractor.size["height"], feature_extractor.size["width"])
train_transforms = Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
- RandomCrop(size=feature_extractor.size),
+ RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
@@ -411,7 +415,7 @@ def main():
val_transforms = Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
- Resize(size=(feature_extractor.size, feature_extractor.size)),
+ Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
index 3c59a0d349..49d6eac687 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
@@ -405,10 +405,15 @@ def main():
# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
+ if "shortest_edge" in feature_extractor.size:
+ # We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
+ size = (feature_extractor.size["shortest_edge"], feature_extractor.size["shortest_edge"])
+ else:
+ size = (feature_extractor.size["height"], feature_extractor.size["width"])
train_transforms = Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
- RandomCrop(size=feature_extractor.size),
+ RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
@@ -420,7 +425,7 @@ def main():
val_transforms = Compose(
[
ReduceLabels() if args.reduce_labels else Identity(),
- Resize(size=(feature_extractor.size, feature_extractor.size)),
+ Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py
index ba9d3c0962..bdd30ecc04 100644
--- a/src/transformers/image_processing_utils.py
+++ b/src/transformers/image_processing_utils.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable, Optional, Union
+
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .feature_extraction_utils import FeatureExtractionMixin
from .utils import logging
@@ -48,7 +50,72 @@ class BaseImageProcessor(ImageProcessorMixin):
super().__init__(**kwargs)
def __call__(self, images, **kwargs) -> BatchFeature:
+ """Preprocess an image or a batch of images."""
return self.preprocess(images, **kwargs)
def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")
+
+
+def get_size_dict(
+ size: Union[int, Iterable[int], Dict[str, int]] = None,
+ max_size: Optional[int] = None,
+ height_width_order: bool = True,
+ default_to_square: bool = True,
+) -> dict:
+ """
+ Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
+ compatibility with the old feature extractor configs and removes ambiguity over whether the tuple is in (height,
+ width) or (width, height) format.
+
+ - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
+ size[0]}` if `height_width_order` is `False`.
+ - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
+ - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
+ is set, it is added to the dict as `{"longest_edge": max_size}`.
+
+ Args:
+ size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
+ The `size` parameter to be cast into a size dictionary.
+ max_size (`Optional[int]`, *optional*):
+ The `max_size` parameter to be cast into a size dictionary.
+ height_width_order (`bool`, *optional*, defaults to `True`):
+ If `size` is a tuple, whether it's in (height, width) or (width, height) order.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ If `size` is an int, whether to default to a square image or not.
+ """
+ # If a dict is passed, we check if it's a valid size dict and then return it.
+ if isinstance(size, dict):
+ size_keys = set(size.keys())
+ if (
+ size_keys != set(["height", "width"])
+ and size_keys != set(["shortest_edge"])
+ and size_keys != set(["shortest_edge", "longest_edge"])
+ ):
+ raise ValueError(
+ "The size dict must contain either the keys ('height', 'width') or ('shortest_edge')"
+ f"or ('shortest_edge', 'longest_edge') but got {size_keys}"
+ )
+ return size
+
+ # By default, if size is an int we assume it represents a tuple of (size, size).
+ elif isinstance(size, int) and default_to_square:
+ if max_size is not None:
+ raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
+ size_dict = {"height": size, "width": size}
+ # In other configs, if size is an int and default_to_square is False, size represents the length of the shortest edge after resizing.
+ elif isinstance(size, int) and not default_to_square:
+ if max_size is not None:
+ size_dict = {"shortest_edge": size, "longest_edge": max_size}
+ else:
+ size_dict = {"shortest_edge": size}
+ elif isinstance(size, (tuple, list)) and height_width_order:
+ size_dict = {"height": size[0], "width": size[1]}
+ elif isinstance(size, (tuple, list)) and not height_width_order:
+ size_dict = {"height": size[1], "width": size[0]}
+
+ logger.warning(
+ "The size parameter should be a dictionary with keys ('height', 'width'), ('shortest_edge', 'longest_edge')"
+ f" or ('shortest_edge',) got {size}. Setting as {size_dict}.",
+ )
+ return size_dict
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index d4826307cb..f0ec214762 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -139,6 +139,9 @@ def to_pil_image(
# If the channel as been moved to first dim, we put it back at the end.
image = to_channel_dimension_format(image, ChannelDimension.LAST)
+ # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
+ image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
+
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale
if do_rescale:
@@ -259,6 +262,9 @@ def resize(
if return_numpy:
resized_image = np.array(resized_image)
+ # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
+ # so we need to add it back if necessary.
+ resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
resized_image = to_channel_dimension_format(resized_image, data_format)
return resized_image
@@ -303,12 +309,14 @@ def normalize(
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
else:
mean = [mean] * num_channels
+ mean = np.array(mean, dtype=image.dtype)
if isinstance(std, Iterable):
if len(std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
else:
std = [std] * num_channels
+ std = np.array(std, dtype=image.dtype)
if input_data_format == ChannelDimension.LAST:
image = (image - mean) / std
@@ -372,6 +380,7 @@ def center_crop(
orig_height, orig_width = get_image_size(image)
crop_height, crop_width = size
+ crop_height, crop_width = int(crop_height), int(crop_width)
# In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
top = (orig_height - crop_height) // 2
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index 42c67a5138..13a82aebc8 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -72,7 +72,15 @@ def is_valid_image(img):
def valid_images(imgs):
- return all(is_valid_image(img) for img in imgs)
+ # If we have an list of images, make sure every image is valid
+ if isinstance(imgs, (list, tuple)):
+ for img in imgs:
+ if not valid_images(img):
+ return False
+ # If not a list of tuple, we have been given a single image or batched tensor of images
+ elif not is_valid_image(imgs):
+ return False
+ return True
def is_batched(img):
diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py
index 82635abc15..15ed5cde37 100644
--- a/src/transformers/models/beit/feature_extraction_beit.py
+++ b/src/transformers/models/beit/feature_extraction_beit.py
@@ -14,258 +14,10 @@
# limitations under the License.
"""Feature extractor class for BEiT."""
-from typing import List, Optional, Tuple, Union
+from ...utils import logging
+from .image_processing_beit import BeitImageProcessor
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_STANDARD_MEAN,
- IMAGENET_STANDARD_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
logger = logging.get_logger(__name__)
-
-class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a BEiT feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_utils.FeatureExtractionMixin`] which contains most of
- the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 256):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
- image is padded with 0's and then center cropped.
- crop_size (`int`, *optional*, defaults to 224):
- Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with `image_mean` and `image_std`.
- image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- reduce_labels (`bool`, *optional*, defaults to `False`):
- Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
- used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
- background label will be replaced by 255.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=256,
- resample=PILImageResampling.BICUBIC,
- do_center_crop=True,
- crop_size=224,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- reduce_labels=False,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.crop_size = crop_size
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
- self.reduce_labels = reduce_labels
-
- def __call__(
- self,
- images: ImageInput,
- segmentation_maps: ImageInput = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
- Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- - **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided)
- """
- # Input type checking for clearer error
- valid_images = False
- valid_segmentation_maps = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- # Check that segmentation maps has a valid type
- if segmentation_maps is not None:
- if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
- valid_segmentation_maps = True
- elif isinstance(segmentation_maps, (list, tuple)):
- if (
- len(segmentation_maps) == 0
- or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
- or is_torch_tensor(segmentation_maps[0])
- ):
- valid_segmentation_maps = True
-
- if not valid_segmentation_maps:
- raise ValueError(
- "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
- " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
- " examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
- if segmentation_maps is not None:
- segmentation_maps = [segmentation_maps]
-
- # reduce zero label if needed
- if self.reduce_labels:
- if segmentation_maps is not None:
- for idx, map in enumerate(segmentation_maps):
- if not isinstance(map, np.ndarray):
- map = np.array(map)
- # avoid using underflow conversion
- map[map == 0] = 255
- map = map - 1
- map[map == 254] = 255
- segmentation_maps[idx] = Image.fromarray(map.astype(np.uint8))
-
- # transformations (resizing + center cropping + normalization)
- if self.do_resize and self.size is not None and self.resample is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if segmentation_maps is not None:
- segmentation_maps = [
- self.resize(map, size=self.size, resample=self.resample) for map in segmentation_maps
- ]
- if self.do_center_crop and self.crop_size is not None:
- images = [self.center_crop(image, self.crop_size) for image in images]
- if segmentation_maps is not None:
- segmentation_maps = [self.center_crop(map, size=self.crop_size) for map in segmentation_maps]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
-
- if segmentation_maps is not None:
- labels = []
- for map in segmentation_maps:
- if not isinstance(map, np.ndarray):
- map = np.array(map)
- labels.append(map.astype(np.int64))
- # cast to np.int64
- data["labels"] = labels
-
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
-
- def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
- """
- Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
-
- Args:
- outputs ([`BeitForSemanticSegmentation`]):
- Raw outputs of the model.
- target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
- List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
- None, predictions will not be resized.
- Returns:
- semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
- segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
- specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
- """
- logits = outputs.logits
-
- # Resize logits and compute semantic segmentation maps
- if target_sizes is not None:
- if len(logits) != len(target_sizes):
- raise ValueError(
- "Make sure that you pass in as many target sizes as the batch dimension of the logits"
- )
-
- if is_torch_tensor(target_sizes):
- target_sizes = target_sizes.numpy()
-
- semantic_segmentation = []
-
- for idx in range(len(logits)):
- resized_logits = torch.nn.functional.interpolate(
- logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
- )
- semantic_map = resized_logits[0].argmax(dim=0)
- semantic_segmentation.append(semantic_map)
- else:
- semantic_segmentation = logits.argmax(dim=1)
- semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
- return semantic_segmentation
+BeitFeatureExtractor = BeitImageProcessor
diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py
new file mode 100644
index 0000000000..b16b94f281
--- /dev/null
+++ b/src/transformers/models/beit/image_processing_beit.py
@@ -0,0 +1,525 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Beit."""
+
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class BeitImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a BEiT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
+ `preprocess` method.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ Can be overridden by the `crop_size` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ The mean to use if normalizing the image. This is a float or list of floats of length of the number of
+ channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ The standard deviation to use if normalizing the image. This is a float or list of floats of length of the
+ number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+ `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_reduce_labels: bool = False,
+ **kwargs
+ ) -> None:
+ if "reduce_labels" in kwargs:
+ warnings.warn(
+ "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use"
+ " `do_reduce_labels` instead.",
+ FutureWarning,
+ )
+ do_reduce_labels = kwargs.pop("reduce_labels")
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 256, "width": 256}
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_reduce_labels = do_reduce_labels
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to (size["height"], size["width"]).
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to (size["height"], size["width"]). If the input size is smaller than `size` along any
+ edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def reduce_label(self, label: ImageInput) -> np.ndarray:
+ label = to_numpy_array(label)
+ # Avoid using underflow conversion
+ label[label == 0] = 255
+ label = label - 1
+ label[label == 254] = 255
+ return label
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_reduce_labels: bool = None,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ ):
+ if do_reduce_labels:
+ image = self.reduce_label(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std)
+
+ return image
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+ image = self._preprocess(
+ image,
+ do_reduce_labels=False,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format)
+ return image
+
+ def _preprocess_segmentation_map(
+ self,
+ segmentation_map: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_reduce_labels: bool = None,
+ ):
+ """Preprocesses a single segmentation map."""
+ # All transformations expect numpy arrays.
+ segmentation_map = to_numpy_array(segmentation_map)
+ # Add an axis to the segmentation maps for transformations.
+ if segmentation_map.ndim == 2:
+ segmentation_map = segmentation_map[None, ...]
+ added_dimension = True
+ else:
+ added_dimension = False
+ segmentation_map = self._preprocess(
+ image=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ resample=resample,
+ size=size,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_normalize=False,
+ do_rescale=False,
+ )
+ # Remove extra axis if added
+ if added_dimension:
+ segmentation_map = np.squeeze(segmentation_map, axis=0)
+ segmentation_map = segmentation_map.astype(np.int64)
+ return segmentation_map
+
+ def __call__(self, images, segmentation_maps=None, **kwargs):
+ # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
+ # be passed in as positional arguments.
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_reduce_labels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be
+ padded with zeros and then cropped
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+
+ if not is_batched(images):
+ images = [images]
+ segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if segmentation_maps is not None and not valid_images(segmentation_maps):
+ raise ValueError(
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ images = [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ do_center_crop=do_center_crop,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ resample=resample,
+ size=size,
+ rescale_factor=rescale_factor,
+ crop_size=crop_size,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ )
+ for img in images
+ ]
+
+ data = {"pixel_values": images}
+
+ if segmentation_maps is not None:
+ segmentation_maps = [
+ self._preprocess_segmentation_map(
+ segmentation_map=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ resample=resample,
+ size=size,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ )
+ for segmentation_map in segmentation_maps
+ ]
+ data["labels"] = segmentation_maps
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
+ """
+ Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+ Args:
+ outputs ([`BeitForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
+ None, predictions will not be resized.
+ Returns:
+ semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
diff --git a/src/transformers/models/clip/feature_extraction_clip.py b/src/transformers/models/clip/feature_extraction_clip.py
index 8afe60337b..51c446e99b 100644
--- a/src/transformers/models/clip/feature_extraction_clip.py
+++ b/src/transformers/models/clip/feature_extraction_clip.py
@@ -14,155 +14,11 @@
# limitations under the License.
"""Feature extractor class for CLIP."""
-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_clip import CLIPImageProcessor
logger = logging.get_logger(__name__)
-class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a CLIP feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int`, *optional*, defaults to 224):
- Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
- image is padded with 0's and then center cropped.
- crop_size (`int`, *optional*, defaults to 224):
- Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with `image_mean` and `image_std`.
- image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- convert_rgb (`bool`, defaults to `True`):
- Whether or not to convert `PIL.Image.Image` into `RGB` format
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BICUBIC,
- do_center_crop=True,
- crop_size=224,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- do_convert_rgb=True,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.crop_size = crop_size
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
- self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
- self.do_convert_rgb = do_convert_rgb
-
- def __call__(
- self,
- images: Union[
- Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
- ],
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model.
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (convert rgb + resizing + center cropping + normalization)
- if self.do_convert_rgb:
- images = [self.convert_rgb(image) for image in images]
- if self.do_resize and self.size is not None and self.resample is not None:
- images = [
- self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
- for image in images
- ]
- if self.do_center_crop and self.crop_size is not None:
- images = [self.center_crop(image, self.crop_size) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+CLIPFeatureExtractor = CLIPImageProcessor
diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py
new file mode 100644
index 0000000000..b81d21e3fc
--- /dev/null
+++ b/src/transformers/models/clip/image_processing_clip.py
@@ -0,0 +1,342 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for CLIP."""
+
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...utils import logging
+from ...utils.import_utils import is_vision_available
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ import PIL
+
+
+def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
+ """
+ Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
+
+ Args:
+ image (`PIL.Image.Image`):
+ The image to convert.
+ """
+ if not isinstance(image, PIL.Image.Image):
+ return image
+
+ return image.convert("RGB")
+
+
+class CLIPImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a CLIP image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+ crop_size (`Dict[str, int]` *optional*, defaults to 224):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize:
+ Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Image standard deviation.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = True,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
+ self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
+ self.do_convert_rgb = do_convert_rgb
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+ resized to keep the input aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+ output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image. If the image is too small to be cropped to the size given, it will be padded (so the
+ returned result will always be of size `size`).
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image in the form of a dictionary with keys `height` and `width`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: int = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_convert_rgb: bool = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+ the longest edge resized to keep the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: defaults to the channel dimension format of the input image.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # PIL RGBA images are converted to RGB
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_center_crop:
+ images = [self.center_crop(image=image, size=crop_size) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/convnext/feature_extraction_convnext.py b/src/transformers/models/convnext/feature_extraction_convnext.py
index 516c7e26f8..c807c37436 100644
--- a/src/transformers/models/convnext/feature_extraction_convnext.py
+++ b/src/transformers/models/convnext/feature_extraction_convnext.py
@@ -14,157 +14,11 @@
# limitations under the License.
"""Feature extractor class for ConvNeXT."""
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_convnext import ConvNextImageProcessor
logger = logging.get_logger(__name__)
-class ConvNextFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a ConvNeXT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize (and optionally center crop) the input to a certain `size`.
- size (`int`, *optional*, defaults to 224):
- Resize the input to the given size. If 384 or larger, the image is resized to (`size`, `size`). Else, the
- smaller edge of the image will be matched to int(`size`/ `crop_pct`), after which the image is cropped to
- `size`. Only has an effect if `do_resize` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- crop_pct (`float`, *optional*):
- The percentage of the image to crop. If `None`, then a cropping percentage of 224 / 256 is used. Only has
- an effect if `do_resize` is set to `True` and `size` < 384.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BICUBIC,
- crop_pct=None,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.crop_pct = crop_pct
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing and optional center cropping + normalization)
- if self.do_resize and self.size is not None:
- if self.size >= 384:
- # warping (no cropping) when evaluated at 384 or larger
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- else:
- if self.crop_pct is None:
- self.crop_pct = 224 / 256
- size = int(self.size / self.crop_pct)
- # to maintain same ratio w.r.t. 224 images
- images = [
- self.resize(image=image, size=size, default_to_square=False, resample=self.resample)
- for image in images
- ]
- images = [self.center_crop(image=image, size=self.size) for image in images]
-
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+ConvNextFeatureExtractor = ConvNextImageProcessor
diff --git a/src/transformers/models/convnext/image_processing_convnext.py b/src/transformers/models/convnext/image_processing_convnext.py
new file mode 100644
index 0000000000..5058e87963
--- /dev/null
+++ b/src/transformers/models/convnext/image_processing_convnext.py
@@ -0,0 +1,310 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ConvNeXT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a ConvNeXT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden
+ by `do_resize` in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
+ Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+ resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+ be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+ be overriden by `size` in the `preprocess` method.
+ crop_pct (`float` *optional*, defaults to 244 / 256):
+ Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
+ overriden by `crop_pct` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ crop_pct: float = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 384}
+ size = get_size_dict(size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ # Default value set here for backwards compatibility where the value in config is None
+ self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ crop_pct: float,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+ `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+ Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+ after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+ crop_pct (`float`):
+ Percentage of the image to crop. Only has an effect if size < 384.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+ shortest_edge = size["shortest_edge"]
+
+ if shortest_edge < 384:
+ # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+ resize_shortest_edge = int(shortest_edge / crop_pct)
+ resize_size = get_resize_output_image_size(image, size=resize_shortest_edge, default_to_square=False)
+ image = resize(image=image, size=resize_size, resample=resample, data_format=data_format, **kwargs)
+ # then crop to (shortest_edge, shortest_edge)
+ return center_crop(image=image, size=(shortest_edge, shortest_edge), data_format=data_format, **kwargs)
+ else:
+ # warping (no cropping) when evaluated at 384 or larger
+ return resize(
+ image, size=(shortest_edge, shortest_edge), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ crop_pct: float = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
+ is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
+ image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+ Percentage of the image to crop if size < 384.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_resize and size["shortest_edge"] < 384 and crop_pct is None:
+ raise ValueError("crop_pct must be specified if size < 384.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
index 72a8be4bef..1c9f58f4a6 100644
--- a/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
@@ -308,7 +308,7 @@ def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_fo
model = CvtForImageClassification(config)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
- feature_extractor.size = image_size
+ feature_extractor.size["shortest_edge"] = image_size
original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))
huggingface_weights = OrderedDict()
diff --git a/src/transformers/models/deit/feature_extraction_deit.py b/src/transformers/models/deit/feature_extraction_deit.py
index e0008ff87d..f47b6fc212 100644
--- a/src/transformers/models/deit/feature_extraction_deit.py
+++ b/src/transformers/models/deit/feature_extraction_deit.py
@@ -14,150 +14,10 @@
# limitations under the License.
"""Feature extractor class for DeiT."""
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_deit import DeiTImageProcessor
logger = logging.get_logger(__name__)
-
-class DeiTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a DeiT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 256):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
- image is padded with 0's and then center cropped.
- crop_size (`int`, *optional*, defaults to 224):
- Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with `image_mean` and `image_std`.
- image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=256,
- resample=PILImageResampling.BICUBIC,
- do_center_crop=True,
- crop_size=224,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.crop_size = crop_size
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + center cropping + normalization)
- if self.do_resize and self.size is not None and self.resample is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if self.do_center_crop and self.crop_size is not None:
- images = [self.center_crop(image, self.crop_size) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+DeiTFeatureExtractor = DeiTImageProcessor
diff --git a/src/transformers/models/deit/image_processing_deit.py b/src/transformers/models/deit/image_processing_deit.py
new file mode 100644
index 0000000000..0b561765ff
--- /dev/null
+++ b/src/transformers/models/deit/image_processing_deit.py
@@ -0,0 +1,315 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DeiT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DeiT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PIL.Image.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 256, "width": 256}
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PIL.Image.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])` using the specified resampling filter.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(crop_size["height"], crop_size["width"])`. If the input size is smaller than
+ `crop_size` along any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample=None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after `resize`.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ PILImageResampling filter to use if resizing the image Only has an effect if `do_resize` is set to
+ `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be
+ padded with zeros and then cropped
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - `None`: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_center_crop:
+ images = [self.center_crop(image=image, size=crop_size) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/dpt/feature_extraction_dpt.py b/src/transformers/models/dpt/feature_extraction_dpt.py
index 14aeb39478..1836f132af 100644
--- a/src/transformers/models/dpt/feature_extraction_dpt.py
+++ b/src/transformers/models/dpt/feature_extraction_dpt.py
@@ -14,235 +14,11 @@
# limitations under the License.
"""Feature extractor class for DPT."""
-from typing import List, Optional, Tuple, Union
+from ...utils import logging
+from .image_processing_dpt import DPTImageProcessor
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_STANDARD_MEAN,
- IMAGENET_STANDARD_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
logger = logging.get_logger(__name__)
-class DPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a DPT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size ('int' or `Tuple(int)`, *optional*, defaults to 384):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- ensure_multiple_of (`int`, *optional*, defaults to 1):
- Ensure that the input is resized to a multiple of this value. Only has an effect if `do_resize` is set to
- `True`.
- keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
- Whether to keep the aspect ratio of the input. Only has an effect if `do_resize` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=384,
- keep_aspect_ratio=False,
- ensure_multiple_of=1,
- resample=PILImageResampling.BILINEAR,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.keep_aspect_ratio = keep_aspect_ratio
- self.ensure_multiple_of = ensure_multiple_of
- self.resample = resample
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-
- def constrain_to_multiple_of(self, size, min_val=0, max_val=None):
- y = (np.round(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
-
- if max_val is not None and y > max_val:
- y = (np.floor(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
-
- if y < min_val:
- y = (np.ceil(size / self.ensure_multiple_of) * self.ensure_multiple_of).astype(int)
-
- return y
-
- def update_size(self, image):
- image = self.to_pil_image(image)
- width, height = image.size
-
- size = self.size
-
- if isinstance(size, list):
- size = tuple(size)
-
- if isinstance(size, int) or len(size) == 1:
- size = (size, size)
-
- # determine new width and height
- scale_width = size[0] / width
- scale_height = size[1] / height
-
- if self.keep_aspect_ratio:
- # scale as least as possbile
- if abs(1 - scale_width) < abs(1 - scale_height):
- # fit width
- scale_height = scale_width
- else:
- # fit height
- scale_width = scale_height
- else:
- new_width = self.constrain_to_multiple_of(scale_width * width)
- new_height = self.constrain_to_multiple_of(scale_height * height)
-
- return (new_width, new_height)
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- for idx, image in enumerate(images):
- size = self.update_size(image)
- images[idx] = self.resize(image, size=size, resample=self.resample)
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
-
- def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
- """
- Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
-
- Args:
- outputs ([`DPTForSemanticSegmentation`]):
- Raw outputs of the model.
- target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
- List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
- None, predictions will not be resized.
- Returns:
- semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
- segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
- specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
- """
- logits = outputs.logits
-
- # Resize logits and compute semantic segmentation maps
- if target_sizes is not None:
- if len(logits) != len(target_sizes):
- raise ValueError(
- "Make sure that you pass in as many target sizes as the batch dimension of the logits"
- )
-
- if is_torch_tensor(target_sizes):
- target_sizes = target_sizes.numpy()
-
- semantic_segmentation = []
-
- for idx in range(len(logits)):
- resized_logits = torch.nn.functional.interpolate(
- logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
- )
- semantic_map = resized_logits[0].argmax(dim=0)
- semantic_segmentation.append(semantic_map)
- else:
- semantic_segmentation = logits.argmax(dim=1)
- semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
- return semantic_segmentation
+DPTFeatureExtractor = DPTImageProcessor
diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py
new file mode 100644
index 0000000000..9653194cbc
--- /dev/null
+++ b/src/transformers/models/dpt/image_processing_dpt.py
@@ -0,0 +1,384 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DPT."""
+
+import math
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ is_batched,
+ is_torch_available,
+ is_torch_tensor,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_resize_output_image_size(
+ input_image: np.ndarray, output_size: Union[int, Iterable[int]], keep_aspect_ratio: bool, multiple: int
+) -> Tuple[int, int]:
+ def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
+ x = round(val / multiple) * multiple
+
+ if max_val is not None and x > max_val:
+ x = math.floor(val / multiple) * multiple
+
+ if x < min_val:
+ x = math.ceil(val / multiple) * multiple
+
+ return x
+
+ output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
+
+ input_height, input_width = get_image_size(input_image)
+ output_height, output_width = output_size
+
+ # determine new height and width
+ scale_height = output_height / input_height
+ scale_width = output_width / input_width
+
+ if keep_aspect_ratio:
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+
+ new_height = constraint_to_multiple_of(scale_height * input_height, multiple=multiple)
+ new_width = constraint_to_multiple_of(scale_width * input_width, multiple=multiple)
+
+ return (new_height, new_width)
+
+
+class DPTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a DPT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
+ Size of the image after resizing. Can be overidden by `size` in `preprocess`.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
+ be overidden by `keep_aspect_ratio` in `preprocess`.
+ ensure_multiple_of (`int`, *optional*, defaults to `1`):
+ If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden
+ by `ensure_multiple_of` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in
+ `preprocess`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overidden by `rescale_factor` in `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ keep_aspect_ratio: bool = False,
+ ensure_multiple_of: int = 1,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 384, "width": 384}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.size = size
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.ensure_multiple_of = ensure_multiple_of
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ keep_aspect_ratio: bool = False,
+ ensure_multiple_of: int = 1,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
+ is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
+ set, the image is resized to a size that is a multiple of this value.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Target size of the output image.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
+ If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
+ ensure_multiple_of (`int`, *optional*, defaults to `1`):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
+ specified in `size`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+ output_size = get_resize_output_image_size(
+ image,
+ output_size=(size["height"], size["width"]),
+ keep_aspect_ratio=keep_aspect_ratio,
+ multiple=ensure_multiple_of,
+ )
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: int = None,
+ keep_aspect_ratio: bool = None,
+ ensure_multiple_of: int = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after reszing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
+ possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
+ resized to a size that is a multiple of this value.
+ keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
+ Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If
+ True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
+ ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
+ Ensure that the image size is a multiple of this value.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
+ ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
+ """
+ Args:
+ Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+ outputs ([`DPTForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+ predictions will not be resized.
+ Returns:
+ semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py
index 7f9456f9e0..0cab53231e 100644
--- a/src/transformers/models/flava/feature_extraction_flava.py
+++ b/src/transformers/models/flava/feature_extraction_flava.py
@@ -14,344 +14,10 @@
# limitations under the License.
"""Feature extractor class for FLAVA."""
-import math
-import random
-from functools import lru_cache
-from typing import Any, List, Optional, Tuple, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_flava import FlavaImageProcessor
logger = logging.get_logger(__name__)
-
-# These values are taken from CLIP
-FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
-FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
-FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
-FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
-LOGIT_LAPLACE_EPS: float = 0.1
-
-
-# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
-class FlavaMaskingGenerator:
- def __init__(
- self,
- input_size: Union[int, Tuple[int, int]] = 14,
- total_mask_patches: int = 75,
- mask_group_max_patches: Optional[int] = None,
- mask_group_min_patches: int = 16,
- mask_group_min_aspect_ratio: Optional[float] = 0.3,
- mask_group_max_aspect_ratio: float = None,
- ):
- if not isinstance(input_size, tuple):
- input_size = (input_size,) * 2
- self.height, self.width = input_size
-
- self.num_patches = self.height * self.width
- self.total_mask_patches = total_mask_patches
-
- self.mask_group_min_patches = mask_group_min_patches
- self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
-
- mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
- self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
-
- def __repr__(self):
- repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
- self.height,
- self.width,
- self.mask_group_min_patches,
- self.mask_group_max_patches,
- self.total_mask_patches,
- self.log_aspect_ratio[0],
- self.log_aspect_ratio[1],
- )
- return repr_str
-
- def get_shape(self):
- return self.height, self.width
-
- def _mask(self, mask, max_mask_patches):
- delta = 0
- for _attempt in range(10):
- target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
- aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
- height = int(round(math.sqrt(target_area * aspect_ratio)))
- width = int(round(math.sqrt(target_area / aspect_ratio)))
- if width < self.width and height < self.height:
- top = random.randint(0, self.height - height)
- left = random.randint(0, self.width - width)
-
- num_masked = mask[top : top + height, left : left + width].sum()
- # Overlap
- if 0 < height * width - num_masked <= max_mask_patches:
- for i in range(top, top + height):
- for j in range(left, left + width):
- if mask[i, j] == 0:
- mask[i, j] = 1
- delta += 1
-
- if delta > 0:
- break
- return delta
-
- def __call__(self):
- mask = np.zeros(shape=self.get_shape(), dtype=int)
- mask_count = 0
- while mask_count < self.total_mask_patches:
- max_mask_patches = self.total_mask_patches - mask_count
- max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
-
- delta = self._mask(mask, max_mask_patches)
- if delta == 0:
- break
- else:
- mask_count += delta
-
- return mask
-
-
-class FlavaFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a FLAVA feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int`, *optional*, defaults to 224):
- Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
- image is padded with 0's and then center cropped.
- crop_size (`int`, *optional*, defaults to 224):
- Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with `image_mean` and `image_std`.
- image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- input_size_patches (`int`, *optional*, defaults to 14):
- Number of patches in the image in height and width direction. 14x14 = 196 total patches.
- total_mask_patches (`int`, *optional*, defaults to 75):
- Total number of patches that should be masked.
- mask_group_min_patches (`int`, *optional*, defaults to 16):
- Minimum number of patches that should be masked.
- mask_group_max_patches (`int`, *optional*, defaults to None):
- Maximum number of patches that should be masked.
- mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
- Minimum aspect ratio of the mask window.
- mask_group_max_aspect_ratio (`float`, *optional*, defaults to None):
- Maximum aspect ratio of the mask window
- codebook_do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input for codebook to a certain `codebook_size`.
- codebook_size (`int`, *optional*, defaults to 224):
- Resize the input for codebook to the given size. Only has an effect if `codebook_do_resize` is set to
- `True`.
- codebook_resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input for codebook at the center. If the input size is smaller than
- `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped.
- codebook_crop_size (`int`, *optional*, defaults to 224):
- Desired output size for codebook input when applying center-cropping. Only has an effect if
- `codebook_do_center_crop` is set to `True`.
- codebook_do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`.
- codebook_image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0, 0, 0]`):
- The sequence of means for each channel, to be used when normalizing images for codebook.
- codebook_image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of standard deviations for each channel, to be used when normalizing images for codebook.
-
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize: bool = True,
- size: Union[int, Tuple[int, int]] = 224,
- resample: int = PILImageResampling.BICUBIC,
- do_center_crop: bool = True,
- crop_size: Union[int, Tuple[int, int]] = 224,
- do_normalize: bool = True,
- image_mean: Tuple[float, float, float] = FLAVA_IMAGE_MEAN,
- image_std: Tuple[float, float, float] = FLAVA_IMAGE_STD,
- # Mask related params
- input_size_patches: int = 14,
- total_mask_patches: int = 75,
- mask_group_min_patches: int = 16,
- mask_group_max_patches: Optional[int] = None,
- mask_group_min_aspect_ratio: float = 0.3,
- mask_group_max_aspect_ratio: Optional[float] = None,
- # Codebook related params
- codebook_do_resize: bool = True,
- codebook_size: bool = 112,
- codebook_resample: int = PILImageResampling.LANCZOS,
- codebook_do_center_crop: bool = True,
- codebook_crop_size: int = 112,
- codebook_do_map_pixels: bool = True,
- codebook_do_normalize: bool = True,
- codebook_image_mean: Tuple[float, float, float] = FLAVA_CODEBOOK_MEAN,
- codebook_image_std: Tuple[float, float, float] = FLAVA_CODEBOOK_STD,
- **kwargs: Any,
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.crop_size = crop_size
- self.do_normalize = do_normalize
- self.image_mean = image_mean
- self.image_std = image_std
-
- self.input_size_patches = input_size_patches
- self.total_mask_patches = total_mask_patches
- self.mask_group_min_patches = mask_group_min_patches
- self.mask_group_max_patches = mask_group_max_patches
- self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
- self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
-
- self.codebook_do_resize = codebook_do_resize
- self.codebook_size = codebook_size
- self.codebook_resample = codebook_resample
- self.codebook_do_center_crop = codebook_do_center_crop
- self.codebook_crop_size = codebook_crop_size
- self.codebook_do_map_pixels = codebook_do_map_pixels
- self.codebook_do_normalize = codebook_do_normalize
- self.codebook_image_mean = codebook_image_mean
- self.codebook_image_std = codebook_image_std
-
- @property
- @lru_cache()
- def masking_generator(self):
- return FlavaMaskingGenerator(
- input_size=self.input_size_patches,
- total_mask_patches=self.total_mask_patches,
- mask_group_min_patches=self.mask_group_min_patches,
- mask_group_max_patches=self.mask_group_max_patches,
- mask_group_min_aspect_ratio=self.mask_group_min_aspect_ratio,
- mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio,
- )
-
- def map_pixels(self, x):
- return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS
-
- def __call__(
- self,
- images: Union[
- Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
- ],
- return_image_mask: Optional[bool] = None,
- return_codebook_pixels: Optional[bool] = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs: Any
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_image_mask (`bool`, *optional*, defaults to None):
- If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version.
-
- return_codebook_pixels (`bool`, *optional*, defaults to None):
- If True, the processor will return `codebook_pixel_values` providing image pixels to be used with the
- default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model.
- """
- # Input type checking for clearer error
- if isinstance(images, (list, tuple)) and len(images) != 0:
- self._ensure_format_supported(images[0])
- else:
- self._ensure_format_supported(images)
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- images_for_codebook = images
-
- # transformations (resizing + center cropping + normalization)
- if self.do_resize and self.size is not None and self.resample is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if self.do_center_crop and self.crop_size is not None:
- images = [self.center_crop(image, self.crop_size) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
- # return as BatchFeature
- data = {"pixel_values": images}
-
- if return_codebook_pixels:
- images = images_for_codebook
- if self.codebook_do_resize and self.codebook_size is not None and self.codebook_resample is not None:
- images = [
- self.resize(image=image, size=self.codebook_size, resample=self.codebook_resample)
- for image in images
- ]
- if self.codebook_do_center_crop and self.codebook_crop_size is not None:
- images = [self.center_crop(image, self.codebook_crop_size) for image in images]
- if self.codebook_do_normalize:
- images = [
- self.normalize(image=image, mean=self.codebook_image_mean, std=self.codebook_image_std)
- for image in images
- ]
- if self.codebook_do_map_pixels:
- images = [self.map_pixels(image) for image in images]
-
- data["codebook_pixel_values"] = images
-
- if return_image_mask:
- masks = [self.masking_generator() for _ in images]
- data["bool_masked_pos"] = masks
-
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+FlavaFeatureExtractor = FlavaImageProcessor
diff --git a/src/transformers/models/flava/image_processing_flava.py b/src/transformers/models/flava/image_processing_flava.py
new file mode 100644
index 0000000000..6c0fa1bff0
--- /dev/null
+++ b/src/transformers/models/flava/image_processing_flava.py
@@ -0,0 +1,696 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Flava."""
+
+import math
+import random
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+# These values are taken from CLIP
+FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
+FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
+FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
+FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
+LOGIT_LAPLACE_EPS: float = 0.1
+
+
+# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
+class FlavaMaskingGenerator:
+ def __init__(
+ self,
+ input_size: Union[int, Tuple[int, int]] = 14,
+ total_mask_patches: int = 75,
+ mask_group_max_patches: Optional[int] = None,
+ mask_group_min_patches: int = 16,
+ mask_group_min_aspect_ratio: Optional[float] = 0.3,
+ mask_group_max_aspect_ratio: float = None,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.height, self.width = input_size
+
+ self.num_patches = self.height * self.width
+ self.total_mask_patches = total_mask_patches
+
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
+
+ mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
+ self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
+
+ def __repr__(self):
+ repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
+ self.height,
+ self.width,
+ self.mask_group_min_patches,
+ self.mask_group_max_patches,
+ self.total_mask_patches,
+ self.log_aspect_ratio[0],
+ self.log_aspect_ratio[1],
+ )
+ return repr_str
+
+ def get_shape(self):
+ return self.height, self.width
+
+ def _mask(self, mask, max_mask_patches):
+ delta = 0
+ for _attempt in range(10):
+ target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ height = int(round(math.sqrt(target_area * aspect_ratio)))
+ width = int(round(math.sqrt(target_area / aspect_ratio)))
+ if width < self.width and height < self.height:
+ top = random.randint(0, self.height - height)
+ left = random.randint(0, self.width - width)
+
+ num_masked = mask[top : top + height, left : left + width].sum()
+ # Overlap
+ if 0 < height * width - num_masked <= max_mask_patches:
+ for i in range(top, top + height):
+ for j in range(left, left + width):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ def __call__(self):
+ mask = np.zeros(shape=self.get_shape(), dtype=int)
+ mask_count = 0
+ while mask_count < self.total_mask_patches:
+ max_mask_patches = self.total_mask_patches - mask_count
+ max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
+
+ delta = self._mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ return mask
+
+
+class FlavaImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Flava image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in
+ `preprocess`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`.
+ crop_size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the
+ `crop_size` parameter in `preprocess`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in `preprocess`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
+ `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ return_image_mask (`bool`, *optional*, defaults to `False`):
+ Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
+ input_size_patches (`int`, *optional*, defaults to 14):
+ Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
+ by the `input_size_patches` parameter in `preprocess`.
+ total_mask_patches (`int`, *optional*, defaults to 75):
+ Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
+ `preprocess`.
+ mask_group_min_patches (`int`, *optional*, defaults to 16):
+ Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
+ parameter in `preprocess`.
+ mask_group_max_patches (`int`, *optional*):
+ Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
+ parameter in `preprocess`.
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
+ Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
+ in `preprocess`.
+ mask_group_max_aspect_ratio (`float`, *optional*):
+ Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
+ in `preprocess`.
+ codebook_do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input for codebook to a certain. Can be overridden by the `codebook_do_resize`
+ parameter in `preprocess`. `codebook_size`.
+ codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
+ `preprocess`.
+ codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+ Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
+ parameter in `preprocess`.
+ codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input for codebook at the center. If the input size is smaller than
+ `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
+ overridden by the `codebook_do_center_crop` parameter in `preprocess`.
+ codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired output size for codebook input when applying center-cropping. Can be overridden by the
+ `codebook_crop_size` parameter in `preprocess`.
+ codebook_do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
+ overridden by the `codebook_do_rescale` parameter in `preprocess`.
+ codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
+ `codebook_rescale_factor` parameter in `preprocess`.
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
+ Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
+ `codebook_do_map_pixels` parameter in `preprocess`.
+ codebook_do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
+ be overridden by the `codebook_do_normalize` parameter in `preprocess`.
+ codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
+ The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
+ by the `codebook_image_mean` parameter in `preprocess`.
+ codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
+ be overridden by the `codebook_image_std` parameter in `preprocess`.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, Iterable[float]]] = None,
+ image_std: Optional[Union[float, Iterable[float]]] = None,
+ # Mask related params
+ return_image_mask: bool = False,
+ input_size_patches: int = 14,
+ total_mask_patches: int = 75,
+ mask_group_min_patches: int = 16,
+ mask_group_max_patches: Optional[int] = None,
+ mask_group_min_aspect_ratio: float = 0.3,
+ mask_group_max_aspect_ratio: Optional[float] = None,
+ # Codebook related params
+ return_codebook_pixels: bool = False,
+ codebook_do_resize: bool = True,
+ codebook_size: bool = None,
+ codebook_resample: int = PILImageResampling.LANCZOS,
+ codebook_do_center_crop: bool = True,
+ codebook_crop_size: int = None,
+ codebook_do_rescale: bool = True,
+ codebook_rescale_factor: Union[int, float] = 1 / 255,
+ codebook_do_map_pixels: bool = True,
+ codebook_do_normalize: bool = True,
+ codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
+ codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+
+ codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
+ codebook_size = get_size_dict(codebook_size)
+ codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
+ codebook_crop_size = get_size_dict(codebook_crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
+ self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD
+
+ self.return_image_mask = return_image_mask
+ self.input_size_patches = input_size_patches
+ self.total_mask_patches = total_mask_patches
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_max_patches = mask_group_max_patches
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
+
+ self.return_codebook_pixels = return_codebook_pixels
+ self.codebook_do_resize = codebook_do_resize
+ self.codebook_size = codebook_size
+ self.codebook_resample = codebook_resample
+ self.codebook_do_center_crop = codebook_do_center_crop
+ self.codebook_crop_size = codebook_crop_size
+ self.codebook_do_rescale = codebook_do_rescale
+ self.codebook_rescale_factor = codebook_rescale_factor
+ self.codebook_do_map_pixels = codebook_do_map_pixels
+ self.codebook_do_normalize = codebook_do_normalize
+ self.codebook_image_mean = codebook_image_mean
+ self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
+ self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
+
+ @lru_cache()
+ def masking_generator(
+ self,
+ input_size_patches,
+ total_mask_patches,
+ mask_group_min_patches,
+ mask_group_max_patches,
+ mask_group_min_aspect_ratio,
+ mask_group_max_aspect_ratio,
+ ) -> FlavaMaskingGenerator:
+ return FlavaMaskingGenerator(
+ input_size=input_size_patches,
+ total_mask_patches=total_mask_patches,
+ mask_group_min_patches=mask_group_min_patches,
+ mask_group_max_patches=mask_group_max_patches,
+ mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
+ mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
+ )
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def map_pixels(self, image: np.ndarray) -> np.ndarray:
+ return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_map_pixels: bool = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample)
+
+ if do_center_crop:
+ image = self.center_crop(image=image, size=crop_size)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std)
+
+ if do_map_pixels:
+ image = self.map_pixels(image)
+
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format)
+ return image
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[Dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ # Mask related params
+ return_image_mask: Optional[bool] = None,
+ input_size_patches: Optional[int] = None,
+ total_mask_patches: Optional[int] = None,
+ mask_group_min_patches: Optional[int] = None,
+ mask_group_max_patches: Optional[int] = None,
+ mask_group_min_aspect_ratio: Optional[float] = None,
+ mask_group_max_aspect_ratio: Optional[float] = None,
+ # Codebook related params
+ return_codebook_pixels: Optional[bool] = None,
+ codebook_do_resize: Optional[bool] = None,
+ codebook_size: Optional[Dict[str, int]] = None,
+ codebook_resample: Optional[int] = None,
+ codebook_do_center_crop: Optional[bool] = None,
+ codebook_crop_size: Optional[Dict[str, int]] = None,
+ codebook_do_rescale: Optional[bool] = None,
+ codebook_rescale_factor: Optional[float] = None,
+ codebook_do_map_pixels: Optional[bool] = None,
+ codebook_do_normalize: Optional[bool] = None,
+ codebook_image_mean: Optional[Iterable[float]] = None,
+ codebook_image_std: Optional[Iterable[float]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`):
+ Whether to return the image mask.
+ input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`):
+ Size of the patches to extract from the image.
+ total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`):
+ Total number of patches to extract from the image.
+ mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`):
+ Minimum number of patches to extract from the image.
+ mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`):
+ Maximum number of patches to extract from the image.
+ mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`):
+ Minimum aspect ratio of the patches to extract from the image.
+ mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`):
+ Maximum aspect ratio of the patches to extract from the image.
+ return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`):
+ Whether to return the codebook pixels.
+ codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`):
+ Whether to resize the codebook pixels.
+ codebook_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_size`):
+ Size of the codebook pixels.
+ codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`):
+ Resampling filter to use if resizing the codebook pixels. This can be one of the enum
+ `PILImageResampling`, Only has an effect if `codebook_do_resize` is set to `True`.
+ codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`):
+ Whether to center crop the codebook pixels.
+ codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`):
+ Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set
+ to `True`.
+ codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`):
+ Whether to rescale the codebook pixels values between [0 - 1].
+ codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`):
+ Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`.
+ codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`):
+ Whether to map the codebook pixels values.
+ codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`):
+ Whether to normalize the codebook pixels.
+ codebook_image_mean (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_mean`):
+ Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`.
+ codebook_image_std (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_std`):
+ Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is
+ set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask
+ input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches
+ total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches
+ mask_group_min_patches = (
+ mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches
+ )
+ mask_group_max_patches = (
+ mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches
+ )
+ mask_group_min_aspect_ratio = (
+ mask_group_min_aspect_ratio
+ if mask_group_min_aspect_ratio is not None
+ else self.mask_group_min_aspect_ratio
+ )
+ mask_group_max_aspect_ratio = (
+ mask_group_max_aspect_ratio
+ if mask_group_max_aspect_ratio is not None
+ else self.mask_group_max_aspect_ratio
+ )
+
+ return_codebook_pixels = (
+ return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels
+ )
+ codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
+ codebook_size = codebook_size if codebook_size is not None else self.codebook_size
+ codebook_size = get_size_dict(codebook_size)
+ codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
+ codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
+ codebook_rescale_factor = (
+ codebook_rescale_factor if codebook_rescale_factor is not None else self.codebook_rescale_factor
+ )
+ codebook_do_center_crop = (
+ codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
+ )
+ codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
+ codebook_crop_size = get_size_dict(codebook_crop_size)
+ codebook_do_map_pixels = (
+ codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
+ )
+ codebook_do_normalize = (
+ codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize
+ )
+ codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
+ codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ processed_images = [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_map_pixels=False,
+ data_format=data_format,
+ )
+ for img in images
+ ]
+ data = {"pixel_values": processed_images}
+
+ if return_codebook_pixels:
+ codebook_images = [
+ self._preprocess_image(
+ image=img,
+ do_resize=codebook_do_resize,
+ size=codebook_size,
+ resample=codebook_resample,
+ do_center_crop=codebook_do_center_crop,
+ crop_size=codebook_crop_size,
+ do_rescale=codebook_do_rescale,
+ rescale_factor=codebook_rescale_factor,
+ do_normalize=codebook_do_normalize,
+ image_mean=codebook_image_mean,
+ image_std=codebook_image_std,
+ do_map_pixels=codebook_do_map_pixels,
+ data_format=data_format,
+ )
+ for img in images
+ ]
+ data["codebook_pixel_values"] = codebook_images
+
+ if return_image_mask:
+ mask_generator = self.masking_generator(
+ input_size_patches=input_size_patches,
+ total_mask_patches=total_mask_patches,
+ mask_group_min_patches=mask_group_min_patches,
+ mask_group_max_patches=mask_group_max_patches,
+ mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
+ mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
+ )
+ masks = [mask_generator() for _ in images]
+ data["bool_masked_pos"] = masks
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py
index 264cbd91f4..242d28f78e 100644
--- a/src/transformers/models/glpn/image_processing_glpn.py
+++ b/src/transformers/models/glpn/image_processing_glpn.py
@@ -37,16 +37,16 @@ class GLPNImageProcessor(BaseImageProcessor):
Args:
do_resize (`bool`, *optional*, defaults to `True`):
- Set the class default for the `do_resize` parameter. Controls whether to resize the image's (height, width)
- dimensions, rounding them down to the closest multiple of `size_divisor`.
+ Whether to resize the image's (height, width) dimensions, rounding them down to the closest multiple of
+ `size_divisor`. Can be overridden by `do_resize` in `preprocess`.
size_divisor (`int`, *optional*, defaults to 32):
- Set the class default for the `size_divisor` parameter. When `do_resize` is `True`, images are resized so
- their height and width are rounded down to the closest multiple of `size_divisor`.
+ When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
+ multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`.
resample (`PIL.Image` resampling filter, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- Set the class default for `resample`. Defines the resampling filter to use if resizing the image.
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
- Set the class default for the `do_rescale` parameter. Controls whether or not to apply the scaling factor
- (to make pixel values floats between 0. and 1.).
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
+ overridden by `do_rescale` in `preprocess`.
"""
model_input_names = ["pixel_values"]
@@ -81,7 +81,7 @@ class GLPNImageProcessor(BaseImageProcessor):
`size_divisor`.
resample:
`PIL.Image` resampling filter to use when resizing the image e.g. `PIL.Image.Resampling.BILINEAR`.
- data_format (`ChannelDimension`, *optional*):
+ data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If `None`, the channel dimension format of the input
image is used. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
@@ -108,7 +108,7 @@ class GLPNImageProcessor(BaseImageProcessor):
The image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
- data_format (`ChannelDimension`, *optional*):
+ data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If `None`, the channel dimension format of the input
image is used. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
@@ -146,14 +146,14 @@ class GLPNImageProcessor(BaseImageProcessor):
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
- return_tensors (`str`, *optional*):
+ return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- `None`: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
- data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
diff --git a/src/transformers/models/imagegpt/feature_extraction_imagegpt.py b/src/transformers/models/imagegpt/feature_extraction_imagegpt.py
index 01628d2f3a..86a197aeae 100644
--- a/src/transformers/models/imagegpt/feature_extraction_imagegpt.py
+++ b/src/transformers/models/imagegpt/feature_extraction_imagegpt.py
@@ -14,168 +14,11 @@
# limitations under the License.
"""Feature extractor class for ImageGPT."""
-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_imagegpt import ImageGPTImageProcessor
logger = logging.get_logger(__name__)
-def squared_euclidean_distance(a, b):
- b = b.T
- a2 = np.sum(np.square(a), axis=1)
- b2 = np.sum(np.square(b), axis=0)
- ab = np.matmul(a, b)
- d = a2[:, None] - 2 * ab + b2[None, :]
- return d
-
-
-def color_quantize(x, clusters):
- x = x.reshape(-1, 3)
- d = squared_euclidean_distance(x, clusters)
- return np.argmin(d, axis=1)
-
-
-class ImageGPTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs an ImageGPT feature extractor. This feature extractor can be used to resize images to a smaller
- resolution (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel
- values" (color clusters).
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- clusters (`np.ndarray`):
- The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)`.
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 32):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input to the range between -1 and +1.
- """
-
- model_input_names = ["input_ids"]
-
- def __init__(
- self, clusters, do_resize=True, size=32, resample=PILImageResampling.BILINEAR, do_normalize=True, **kwargs
- ):
- super().__init__(**kwargs)
- self.clusters = np.asarray(clusters)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_normalize = do_normalize
-
- def normalize(self, image):
- """
- Normalizes `image` into the range -1 to +1.
-
- Args:
- image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
- The image to normalize.
-
- Returns:
- `np.ndarray`: The normalized image.
- """
- image = self.to_numpy_array(image, rescale=False, channel_first=False)
-
- return image / 127.5 - 1
-
- def __call__(
- self,
- images: Union[
- Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
- ],
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- Input IDs to be fed to a model, of shape `(batch_size, height * width)`.
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- images = [self.resize(image, size=self.size, resample=self.resample) for image in images]
-
- if self.do_normalize:
- images = [self.normalize(image) for image in images]
-
- # color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
- images = np.array(images)
- images = color_quantize(images, self.clusters).reshape(images.shape[:-1])
-
- # flatten to (batch_size, height*width)
- batch_size = images.shape[0]
- images = images.reshape(batch_size, -1)
-
- # return as BatchFeature
- data = {"input_ids": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+ImageGPTFeatureExtractor = ImageGPTImageProcessor
diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt.py b/src/transformers/models/imagegpt/image_processing_imagegpt.py
new file mode 100644
index 0000000000..af4db85670
--- /dev/null
+++ b/src/transformers/models/imagegpt/image_processing_imagegpt.py
@@ -0,0 +1,239 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ImageGPT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import rescale, resize, to_channel_dimension_format
+from ...image_utils import ChannelDimension, ImageInput, PILImageResampling, is_batched, to_numpy_array, valid_images
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def squared_euclidean_distance(a, b):
+ b = b.T
+ a2 = np.sum(np.square(a), axis=1)
+ b2 = np.sum(np.square(b), axis=0)
+ ab = np.matmul(a, b)
+ d = a2[:, None] - 2 * ab + b2[None, :]
+ return d
+
+
+def color_quantize(x, clusters):
+ x = x.reshape(-1, 3)
+ d = squared_euclidean_distance(x, clusters)
+ return np.argmin(d, axis=1)
+
+
+class ImageGPTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a ImageGPT image processor. This image processor can be used to resize images to a smaller resolution
+ (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel values"
+ (color clusters).
+
+ Args:
+ clusters (`np.ndarray`, *optional*):
+ The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)` when color quantizing. Can be
+ overriden by `clusters` in `preprocess`.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
+ `do_resize` in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image pixel value to between [-1, 1]. Can be overridden by `do_normalize` in
+ `preprocess`.
+ do_color_quantize (`bool`, *optional*, defaults to `True`):
+ Whether to color quantize the image. Can be overridden by `do_color_quantize` in `preprocess`.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ # clusters is a first argument to maintain backwards compatibility with the old ImageGPTFeatureExtractor
+ clusters: Optional[np.ndarray] = None,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_normalize: bool = True,
+ do_color_quantize: bool = True,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 256, "width": 256}
+ size = get_size_dict(size)
+ self.clusters = clusters
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.do_color_quantize = do_color_quantize
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to (size["height"], size["width"]).
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"Size dictionary must contain both height and width keys. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Normalizes an images' pixel values to between [-1, 1].
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ image = rescale(image=image, scale=1 / 127.5, data_format=data_format)
+ image = image - 1
+ return image
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_normalize: bool = None,
+ do_color_quantize: Optional[bool] = None,
+ clusters: Optional[Union[int, List[int]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image
+ do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
+ Whether to color quantize the image.
+ clusters (`np.ndarray`, *optional*, defaults to `self.clusters`):
+ Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
+ `do_color_quantize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ Only has an effect if `do_color_quantize` is set to `False`.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
+ clusters = clusters if clusters is not None else self.clusters
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_color_quantize and clusters is None:
+ raise ValueError("Clusters must be specified if do_color_quantize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image) for image in images]
+
+ if do_color_quantize:
+ images = [to_channel_dimension_format(image, ChannelDimension.LAST) for image in images]
+ # color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
+ images = np.array(images)
+ clusters = np.array(clusters)
+ images = color_quantize(images, clusters).reshape(images.shape[:-1])
+
+ # flatten to (batch_size, height*width)
+ batch_size = images.shape[0]
+ images = images.reshape(batch_size, -1)
+
+ # We need to convert back to a list of images to keep consistent behaviour across processors.
+ images = list(images)
+ else:
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"input_ids": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index 4d14fef4f0..9d89251841 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -992,11 +992,12 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
... )
>>> clusters = feature_extractor.clusters
- >>> n_px = feature_extractor.size
+ >>> height = feature_extractor.size["height"]
+ >>> width = feature_extractor.size["width"]
>>> samples = output[:, 1:].cpu().detach().numpy()
>>> samples_img = [
- ... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples
+ ... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
... ] # convert color cluster tokens back to pixels
>>> f, axes = plt.subplots(1, batch_size, dpi=300)
diff --git a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
index 8adf8b7911..606e708381 100644
--- a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py
@@ -16,226 +16,10 @@
Feature extractor class for LayoutLMv2.
"""
-from typing import List, Optional, Union
+from ...utils import logging
+from .image_processing_layoutlmv2 import LayoutLMv2ImageProcessor
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
-
-
-# soft dependency
-if is_pytesseract_available():
- import pytesseract
logger = logging.get_logger(__name__)
-ImageInput = Union[
- Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
-]
-
-
-def normalize_box(box, width, height):
- return [
- int(1000 * (box[0] / width)),
- int(1000 * (box[1] / height)),
- int(1000 * (box[2] / width)),
- int(1000 * (box[3] / height)),
- ]
-
-
-def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
- """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
-
- # apply OCR
- data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
- words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
-
- # filter empty words and corresponding coordinates
- irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
- words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
- left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
- top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
- width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
- height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
-
- # turn coordinates into (left, top, left+width, top+height) format
- actual_boxes = []
- for x, y, w, h in zip(left, top, width, height):
- actual_box = [x, y, x + w, y + h]
- actual_boxes.append(actual_box)
-
- image_width, image_height = image.size
-
- # finally, normalize the bounding boxes
- normalized_boxes = []
- for box in actual_boxes:
- normalized_boxes.append(normalize_box(box, image_width, image_height))
-
- assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
-
- return words, normalized_boxes
-
-
-class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a LayoutLMv2 feature extractor. This can be used to resize document images to the same size, as well as
- to apply OCR on them in order to get a list of words and normalized bounding boxes.
-
- This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
- of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 224):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- apply_ocr (`bool`, *optional*, defaults to `True`):
- Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
- ocr_lang (`str`, *optional*):
- The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
- used.
- tesseract_config (`str`, *optional*):
- Any additional custom configuration flags that are forwarded to the `config` parameter when calling
- Tesseract. For example: '--psm 6'.
-
-
-
- LayoutLMv2FeatureExtractor uses Google's Tesseract OCR engine under the hood.
-
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BILINEAR,
- apply_ocr=True,
- ocr_lang=None,
- tesseract_config="",
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.apply_ocr = apply_ocr
- self.ocr_lang = ocr_lang
- self.tesseract_config = tesseract_config
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv2FeatureExtractor`] was
- initialized with `apply_ocr` set to `True`).
- - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
- (only when [`LayoutLMv2FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
-
- Examples:
-
- ```python
- >>> from transformers import LayoutLMv2FeatureExtractor
- >>> from PIL import Image
-
- >>> # Document can be a png, jpg, etc. PDFs must be converted to images.
- >>> image = Image.open(name_of_your_document).convert("RGB")
-
- >>> # option 1: with apply_ocr=True (default)
- >>> feature_extractor = LayoutLMv2FeatureExtractor()
- >>> encoding = feature_extractor(image, return_tensors="pt")
- >>> print(encoding.keys())
- >>> # dict_keys(['pixel_values', 'words', 'boxes'])
-
- >>> # option 2: with apply_ocr=False
- >>> feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
- >>> encoding = feature_extractor(image, return_tensors="pt")
- >>> print(encoding.keys())
- >>> # dict_keys(['pixel_values'])
- ```"""
-
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
- f"but is of type {type(images)}."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # Tesseract OCR to get words + normalized bounding boxes
- if self.apply_ocr:
- requires_backends(self, "pytesseract")
- words_batch = []
- boxes_batch = []
- for image in images:
- words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
- words_batch.append(words)
- boxes_batch.append(boxes)
-
- # transformations (resizing)
- if self.do_resize and self.size is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
-
- images = [self.to_numpy_array(image, rescale=False) for image in images]
- # flip color channels from RGB to BGR (as Detectron2 requires this)
- images = [image[::-1, :, :] for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- if self.apply_ocr:
- encoded_inputs["words"] = words_batch
- encoded_inputs["boxes"] = boxes_batch
-
- return encoded_inputs
+LayoutLMv2FeatureExtractor = LayoutLMv2ImageProcessor
diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
new file mode 100644
index 0000000000..fb3e2c5a01
--- /dev/null
+++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
@@ -0,0 +1,268 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LayoutLMv2."""
+
+from typing import Dict, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import is_pytesseract_available, logging, requires_backends
+
+
+if is_vision_available():
+ import PIL
+
+# soft dependency
+if is_pytesseract_available():
+ import pytesseract
+
+logger = logging.get_logger(__name__)
+
+
+def normalize_box(box, width, height):
+ return [
+ int(1000 * (box[0] / width)),
+ int(1000 * (box[1] / height)),
+ int(1000 * (box[2] / width)),
+ int(1000 * (box[3] / height)),
+ ]
+
+
+def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str] = None):
+ """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
+ tesseract_config = tesseract_config if tesseract_config is not None else ""
+
+ # apply OCR
+ pil_image = to_pil_image(image)
+ image_width, image_height = pil_image.size
+ data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
+ words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
+
+ # filter empty words and corresponding coordinates
+ irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
+ words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
+ left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
+ top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
+ width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
+ height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
+
+ # turn coordinates into (left, top, left+width, top+height) format
+ actual_boxes = []
+ for x, y, w, h in zip(left, top, width, height):
+ actual_box = [x, y, x + w, y + h]
+ actual_boxes.append(actual_box)
+
+ # finally, normalize the bounding boxes
+ normalized_boxes = []
+ for box in actual_boxes:
+ normalized_boxes.append(normalize_box(box, image_width, image_height))
+
+ assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
+
+ return words, normalized_boxes
+
+
+def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
+ input_data_format = infer_channel_dimension_format(image)
+ if input_data_format == ChannelDimension.LAST:
+ image = image[..., ::-1]
+ elif input_data_format == ChannelDimension.FIRST:
+ image = image[:, ::-1, ...]
+ else:
+ raise ValueError(f"Unsupported channel dimension: {input_data_format}")
+
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format)
+ return image
+
+
+class LayoutLMv2ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LayoutLMv2 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
+ overridden by `do_resize` in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ apply_ocr (`bool`, *optional*, defaults to `True`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
+ `apply_ocr` in `preprocess`.
+ ocr_lang (`str`, *optional*):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used. Can be overridden by `ocr_lang` in `preprocess`.
+ tesseract_config (`str`, *optional*):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract. For example: '--psm 6'. Can be overridden by `tesseract_config` in `preprocess`.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ apply_ocr: bool = True,
+ ocr_lang: Optional[str] = None,
+ tesseract_config: Optional[str] = "",
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.apply_ocr = apply_ocr
+ self.ocr_lang = ocr_lang
+ self.tesseract_config = tesseract_config
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ apply_ocr: bool = None,
+ ocr_lang: Optional[str] = None,
+ tesseract_config: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Desired size of the output image after resizing.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PIL.Image` resampling
+ filter. Only has an effect if `do_resize` is set to `True`.
+ apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
+ ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used.
+ tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
+ ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
+ tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if apply_ocr:
+ requires_backends(self, "pytesseract")
+ words_batch = []
+ boxes_batch = []
+ for image in images:
+ words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
+ words_batch.append(words)
+ boxes_batch.append(boxes)
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ # flip color channels from RGB to BGR (as Detectron2 requires this)
+ images = [flip_channel_order(image) for image in images]
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+ if apply_ocr:
+ data["words"] = words_batch
+ data["boxes"] = boxes_batch
+ return data
diff --git a/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
index 1944a79ffb..d742c068fc 100644
--- a/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
+++ b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py
@@ -16,235 +16,11 @@
Feature extractor class for LayoutLMv3.
"""
-from typing import List, Optional, Union
+from ...utils import logging
+from .image_processing_layoutlmv3 import LayoutLMv3ImageProcessor
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
-
-
-# soft dependency
-if is_pytesseract_available():
- import pytesseract
logger = logging.get_logger(__name__)
-ImageInput = Union[
- Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
-]
-
-def normalize_box(box, width, height):
- return [
- int(1000 * (box[0] / width)),
- int(1000 * (box[1] / height)),
- int(1000 * (box[2] / width)),
- int(1000 * (box[3] / height)),
- ]
-
-
-def apply_tesseract(image: Image.Image, lang: Optional[str], tesseract_config: Optional[str]):
- """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
- # apply OCR
- data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
- words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
-
- # filter empty words and corresponding coordinates
- irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
- words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
- left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
- top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
- width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
- height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
-
- # turn coordinates into (left, top, left+width, top+height) format
- actual_boxes = []
- for x, y, w, h in zip(left, top, width, height):
- actual_box = [x, y, x + w, y + h]
- actual_boxes.append(actual_box)
-
- image_width, image_height = image.size
-
- # finally, normalize the bounding boxes
- normalized_boxes = []
- for box in actual_boxes:
- normalized_boxes.append(normalize_box(box, image_width, image_height))
-
- assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
-
- return words, normalized_boxes
-
-
-class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to
- apply OCR on them in order to get a list of words and normalized bounding boxes.
-
- This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
- of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 224):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- apply_ocr (`bool`, *optional*, defaults to `True`):
- Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
- ocr_lang (`str`, *optional*):
- The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
- used.
- tesseract_config (`str`, *optional*):
- Any additional custom configuration flags that are forwarded to the `config` parameter when calling
- Tesseract. For example: '--psm 6'.
-
-
-
- LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood.
-
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BILINEAR,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- apply_ocr=True,
- ocr_lang=None,
- tesseract_config="",
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
- self.apply_ocr = apply_ocr
- self.ocr_lang = ocr_lang
- self.tesseract_config = tesseract_config
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was
- initialized with `apply_ocr` set to `True`).
- - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
- (only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
-
- Examples:
-
- ```python
- >>> from transformers import LayoutLMv3FeatureExtractor
- >>> from PIL import Image
-
- >>> # Document can be a png, jpg, etc. PDFs must be converted to images.
- >>> image = Image.open(name_of_your_document).convert("RGB")
-
- >>> # option 1: with apply_ocr=True (default)
- >>> feature_extractor = LayoutLMv3FeatureExtractor()
- >>> encoding = feature_extractor(image, return_tensors="pt")
- >>> print(encoding.keys())
- >>> # dict_keys(['pixel_values', 'words', 'boxes'])
-
- >>> # option 2: with apply_ocr=False
- >>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
- >>> encoding = feature_extractor(image, return_tensors="pt")
- >>> print(encoding.keys())
- >>> # dict_keys(['pixel_values'])
- ```"""
-
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
- f"but is of type {type(images)}."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # Tesseract OCR to get words + normalized bounding boxes
- if self.apply_ocr:
- requires_backends(self, "pytesseract")
- words_batch = []
- boxes_batch = []
- for image in images:
- words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang, self.tesseract_config)
- words_batch.append(words)
- boxes_batch.append(boxes)
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- if self.apply_ocr:
- encoded_inputs["words"] = words_batch
- encoded_inputs["boxes"] = boxes_batch
-
- return encoded_inputs
+LayoutLMv3FeatureExtractor = LayoutLMv3ImageProcessor
diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py
new file mode 100644
index 0000000000..d18171610a
--- /dev/null
+++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py
@@ -0,0 +1,371 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LayoutLMv3."""
+
+from typing import Dict, Iterable, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format, to_pil_image
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import is_pytesseract_available, logging, requires_backends
+
+
+if is_vision_available():
+ import PIL
+
+# soft dependency
+if is_pytesseract_available():
+ import pytesseract
+
+logger = logging.get_logger(__name__)
+
+
+def normalize_box(box, width, height):
+ return [
+ int(1000 * (box[0] / width)),
+ int(1000 * (box[1] / height)),
+ int(1000 * (box[2] / width)),
+ int(1000 * (box[3] / height)),
+ ]
+
+
+def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str]):
+ """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
+
+ # apply OCR
+ pil_image = to_pil_image(image)
+ image_width, image_height = pil_image.size
+ data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
+ words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
+
+ # filter empty words and corresponding coordinates
+ irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
+ words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
+ left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
+ top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
+ width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
+ height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
+
+ # turn coordinates into (left, top, left+width, top+height) format
+ actual_boxes = []
+ for x, y, w, h in zip(left, top, width, height):
+ actual_box = [x, y, x + w, y + h]
+ actual_boxes.append(actual_box)
+
+ # finally, normalize the bounding boxes
+ normalized_boxes = []
+ for box in actual_boxes:
+ normalized_boxes.append(normalize_box(box, image_width, image_height))
+
+ assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
+
+ return words, normalized_boxes
+
+
+def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
+ input_data_format = infer_channel_dimension_format(image)
+ if input_data_format == ChannelDimension.LAST:
+ image = image[..., ::-1]
+ elif input_data_format == ChannelDimension.FIRST:
+ image = image[:, ::-1, ...]
+ else:
+ raise ValueError(f"Unsupported channel dimension: {input_data_format}")
+
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format)
+ return image
+
+
+class LayoutLMv3ImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LayoutLMv3 image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to `(size["height"], size["width"])`. Can be
+ overridden by `do_resize` in `preprocess`.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after resizing. Can be overridden by `size` in `preprocess`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image's pixel values by the specified `rescale_value`. Can be overridden by
+ `do_rescale` in `preprocess`.
+ rescale_factor (`float`, *optional*, defaults to 1 / 255):
+ Value by which the image's pixel values are rescaled. Can be overridden by `rescale_factor` in
+ `preprocess`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`Iterable[float]` or `float`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`Iterable[float]` or `float`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ apply_ocr (`bool`, *optional*, defaults to `True`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. Can be overridden by
+ the `apply_ocr` parameter in the `preprocess` method.
+ ocr_lang (`str`, *optional*):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used. Can be overridden by the `ocr_lang` parameter in the `preprocess` method.
+ tesseract_config (`str`, *optional*):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract. For example: '--psm 6'. Can be overridden by the `tesseract_config` parameter in the
+ `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_value: float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Union[float, Iterable[float]] = None,
+ image_std: Union[float, Iterable[float]] = None,
+ apply_ocr: bool = True,
+ ocr_lang: Optional[str] = None,
+ tesseract_config: Optional[str] = "",
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_value
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.apply_ocr = apply_ocr
+ self.ocr_lang = ocr_lang
+ self.tesseract_config = tesseract_config
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to (size["height"], size["width"]) dimensions.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+ output_size = (size["height"], size["width"])
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `Iterable[float]`):
+ Mean values to be used for normalization.
+ std (`float` or `Iterable[float]`):
+ Standard deviation values to be used for normalization.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample=None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Union[float, Iterable[float]] = None,
+ image_std: Union[float, Iterable[float]] = None,
+ apply_ocr: bool = None,
+ ocr_lang: Optional[str] = None,
+ tesseract_config: Optional[str] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Desired size of the output image after applying `resize`.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters.
+ Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image pixel values between [0, 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to apply to the image pixel values. Only has an effect if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `Iterable[float]`, *optional*, defaults to `self.image_mean`):
+ Mean values to be used for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `Iterable[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation values to be used for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ apply_ocr (`bool`, *optional*, defaults to `self.apply_ocr`):
+ Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
+ ocr_lang (`str`, *optional*, defaults to `self.ocr_lang`):
+ The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
+ used.
+ tesseract_config (`str`, *optional*, defaults to `self.tesseract_config`):
+ Any additional custom configuration flags that are forwarded to the `config` parameter when calling
+ Tesseract.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr
+ ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang
+ tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("If do_normalize is True, image_mean and image_std must be specified.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ # Tesseract OCR to get words + normalized bounding boxes
+ if apply_ocr:
+ requires_backends(self, "pytesseract")
+ words_batch = []
+ boxes_batch = []
+ for image in images:
+ words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
+ words_batch.append(words)
+ boxes_batch.append(boxes)
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ # flip color channels from RGB to BGR (as Detectron2 requires this)
+ images = [flip_channel_order(image) for image in images]
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+ if apply_ocr:
+ data["words"] = words_batch
+ data["boxes"] = boxes_batch
+ return data
diff --git a/src/transformers/models/levit/feature_extraction_levit.py b/src/transformers/models/levit/feature_extraction_levit.py
index a4e359bc81..c282c73d7b 100644
--- a/src/transformers/models/levit/feature_extraction_levit.py
+++ b/src/transformers/models/levit/feature_extraction_levit.py
@@ -14,148 +14,12 @@
# limitations under the License.
"""Feature extractor class for LeViT."""
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_levit import LevitImageProcessor
logger = logging.get_logger(__name__)
-class LevitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a LeViT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the shortest edge of the input to int(256/224 *`size`).
- size (`int` or `Tuple(int)`, *optional*, defaults to 224):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then shorter side of input will be resized to 'size'.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether or not to center crop the input to `size`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BICUBIC,
- do_center_crop=True,
- do_normalize=True,
- image_mean=IMAGENET_DEFAULT_MEAN,
- image_std=IMAGENET_DEFAULT_STD,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.do_normalize = do_normalize
- self.image_mean = image_mean
- self.image_std = image_std
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + center cropping + normalization)
- if self.do_resize and self.size is not None:
- size_ = int((256 / 224) * self.size)
- images = [
- self.resize(image=image, size=size_, resample=self.resample, default_to_square=False)
- for image in images
- ]
- if self.do_center_crop:
- images = [self.center_crop(image=image, size=self.size) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+# Feature extractor for Levit is being replaced by image processor
+LevitFeatureExtractor = LevitImageProcessor
diff --git a/src/transformers/models/levit/image_processing_levit.py b/src/transformers/models/levit/image_processing_levit.py
new file mode 100644
index 0000000000..4917583357
--- /dev/null
+++ b/src/transformers/models/levit/image_processing_levit.py
@@ -0,0 +1,342 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for LeViT."""
+
+from typing import Dict, Iterable, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LevitImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a LeViT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Wwhether to resize the shortest edge of the input to int(256/224 *`size`). Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the output image after resizing. If size is a dict with keys "width" and "height", the image will
+ be resized to `(size["height"], size["width"])`. If size is a dict with key "shortest_edge", the shortest
+ edge value `c` is rescaled to `int(c * (256/224))`. The smaller edge of the image will be matched to this
+ value i.e, if height > width, then image will be rescaled to `(size["shortest_egde"] * height / width,
+ size["shortest_egde"])`. Can be overridden by the `size` parameter in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether or not to center crop the input to `(crop_size["height"], crop_size["width"])`. Can be overridden
+ by the `do_center_crop` parameter in the `preprocess` method.
+ crop_size (`Dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Desired image size after `center_crop`. Can be overridden by the `crop_size` parameter in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+ `preprocess` method.
+ image_mean (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_MEAN,
+ image_std: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_STD,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ If size is a dict with keys "width" and "height", the image will be resized to `(size["height"],
+ size["width"])`.
+
+ If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`.
+ The smaller edge of the image will be matched to this value i.e, if height > width, then image will be rescaled
+ to `(size["shortest_egde"] * height / width, size["shortest_egde"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
+ will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
+ `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value
+ i.e, if height > width, then image will be rescaled to (size * height / width, size).
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size_dict = get_size_dict(size, default_to_square=False)
+ # size_dict is a dict with either keys "height" and "width" or "shortest_edge"
+ if "shortest_edge" in size:
+ shortest_edge = int((256 / 224) * size["shortest_edge"])
+ output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
+ size_dict = {"height": output_size[0], "width": output_size[1]}
+ if "height" not in size_dict or "width" not in size_dict:
+ raise ValueError(
+ f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
+ )
+ return resize(
+ image, size=(size_dict["height"], size_dict["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Dict `{"height": int, "width": int}` specifying the size of the output image after cropping.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `List[float]`):
+ Image mean.
+ std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[Dict[str, int]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, Iterable[float]]] = None,
+ image_std: Optional[Union[float, Iterable[float]]] = None,
+ return_tensors: Optional[TensorType] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images to be used as input to a LeViT model.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
+ will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
+ `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value
+ i.e, if height > width, then image will be rescaled to (size * height / width, size).
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the output image after center cropping. Crops images to (crop_size["height"],
+ crop_size["width"]).
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image pixel values by `rescaling_factor` - typical to values between 0 and 1.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Factor to rescale the image pixel values by.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image pixel values by `image_mean` and `image_std`.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to normalize the image pixel values by.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to normalize the image pixel values by.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image, size, resample) for image in images]
+
+ if do_center_crop:
+ images = [self.center_crop(image, crop_size) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image, rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image, image_mean, image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
index 5c38b8e240..d546bb2218 100644
--- a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
+++ b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
@@ -14,189 +14,11 @@
# limitations under the License.
"""Feature extractor class for MobileViT."""
-from typing import List, Optional, Tuple, Union
+from ...utils import logging
+from .image_processing_mobilevit import MobileViTImageProcessor
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
logger = logging.get_logger(__name__)
-class MobileViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a MobileViT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 288):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to match the shorter side. Only has an effect if
- `do_resize` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
- image is padded with 0's and then center cropped.
- crop_size (`int`, *optional*, defaults to 256):
- Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
- do_flip_channel_order (`bool`, *optional*, defaults to `True`):
- Whether to flip the color channels from RGB to BGR.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=288,
- resample=PILImageResampling.BILINEAR,
- do_center_crop=True,
- crop_size=256,
- do_flip_channel_order=True,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.crop_size = crop_size
- self.do_flip_channel_order = do_flip_channel_order
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- images = [
- self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
- for image in images
- ]
- if self.do_center_crop and self.crop_size is not None:
- images = [self.center_crop(image, self.crop_size) for image in images]
-
- images = [self.to_numpy_array(image) for image in images]
-
- # the pretrained checkpoints assume images are BGR, not RGB
- if self.do_flip_channel_order:
- images = [self.flip_channel_order(image) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
-
- def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
- """
- Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
- PyTorch.
-
- Args:
- outputs ([`MobileViTForSemanticSegmentation`]):
- Raw outputs of the model.
- target_sizes (`List[Tuple]`, *optional*):
- A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
- final size (height, width) of each prediction. If left to None, predictions will not be resized.
- Returns:
- `List[torch.Tensor]`:
- A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
- corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
- `torch.Tensor` correspond to a semantic class id.
- """
- logits = outputs.logits
-
- # Resize logits and compute semantic segmentation maps
- if target_sizes is not None:
- if len(logits) != len(target_sizes):
- raise ValueError(
- "Make sure that you pass in as many target sizes as the batch dimension of the logits"
- )
-
- if is_torch_tensor(target_sizes):
- target_sizes = target_sizes.numpy()
-
- semantic_segmentation = []
-
- for idx in range(len(logits)):
- resized_logits = torch.nn.functional.interpolate(
- logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
- )
- semantic_map = resized_logits[0].argmax(dim=0)
- semantic_segmentation.append(semantic_map)
- else:
- semantic_segmentation = logits.argmax(dim=1)
- semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
- return semantic_segmentation
+MobileViTFeatureExtractor = MobileViTImageProcessor
diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py
new file mode 100644
index 0000000000..fc017c2ccf
--- /dev/null
+++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py
@@ -0,0 +1,364 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for MobileViT."""
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension]) -> np.ndarray:
+ """
+ Flip the color channels from RGB to BGR or vice versa.
+
+ Args:
+ image (`np.ndarray`):
+ The image, represented as a numpy array.
+ data_format (`ChannelDimension`, *`optional`*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+
+ Returns:
+ `np.ndarray`: The image with the flipped color channels.
+ """
+ input_data_format = infer_channel_dimension_format(image)
+
+ if input_data_format == ChannelDimension.LAST:
+ image = image[..., ::-1]
+ elif input_data_format == ChannelDimension.FIRST:
+ image = image[:, ::-1, ...]
+ else:
+ raise ValueError(f"Invalid input channel dimension format: {input_data_format}")
+
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format)
+
+ return image
+
+
+class MobileViTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a MobileViT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Controls the size of the output image after resizing. Can be overridden by the `size` parameter in the
+ `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
+ in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in
+ the `preprocess` method.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Desired output size `(size["height"], size["width"])` when applying center-cropping. Can be overridden by
+ the `crop_size` parameter in the `preprocess` method.
+ do_flip_channel_order (`bool`, *optional*, defaults to `True`):
+ Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
+ parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_flip_channel_order: bool = True,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
+ crop_size = get_size_dict(crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_flip_channel_order = do_flip_channel_order
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PIL.Image.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Controls the size of the output image. The shortest edge of the image will be resized to
+ `size["shortest_edge"]` while maintaining the aspect ratio.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
+ output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to size `(size["height], size["width"])`. If the input size is smaller than `size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def flip_channel_order(
+ self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None
+ ) -> np.ndarray:
+ """
+ Flip the color channels from RGB to BGR or vice versa.
+
+ Args:
+ image (`np.ndarray`):
+ The image, represented as a numpy array.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return flip_channel_order(image, data_format=data_format)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_flip_channel_order: bool = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image by rescale factor.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop if `do_center_crop` is set to `True`.
+ do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
+ Whether to flip the channel order of the image.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_flip_channel_order = (
+ do_flip_channel_order if do_flip_channel_order is not None else self.do_flip_channel_order
+ )
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_center_crop:
+ images = [self.center_crop(image=image, size=crop_size) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ # the pretrained checkpoints assume images are BGR, not RGB
+ if do_flip_channel_order:
+ images = [self.flip_channel_order(image=image) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
+ """
+ Args:
+ Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
+ PyTorch.
+ outputs ([`MobileViTForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`List[Tuple]`, *optional*):
+ A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
+ final size (height, width) of each prediction. If left to None, predictions will not be resized.
+ Returns:
+ `List[torch.Tensor]`:
+ A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+ corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
+ `torch.Tensor` correspond to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
diff --git a/src/transformers/models/perceiver/feature_extraction_perceiver.py b/src/transformers/models/perceiver/feature_extraction_perceiver.py
index dfa8b13304..ee0af76eed 100644
--- a/src/transformers/models/perceiver/feature_extraction_perceiver.py
+++ b/src/transformers/models/perceiver/feature_extraction_perceiver.py
@@ -14,179 +14,11 @@
# limitations under the License.
"""Feature extractor class for Perceiver."""
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_perceiver import PerceiverImageProcessor
logger = logging.get_logger(__name__)
-class PerceiverFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a Perceiver feature extractor.
-
- This feature extractor inherits from [`ImageFeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
- image is padded with 0's and then center cropped.
- crop_size (`int`, *optional*, defaults to 256):
- Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 224):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with `image_mean` and `image_std`.
- image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_center_crop=True,
- crop_size=256,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BICUBIC,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_center_crop = do_center_crop
- self.crop_size = crop_size
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
- def center_crop(self, image):
- """
- Crops `image` to *self.crop_size* using a center crop. Note that if the image is too small to be cropped to the
- size given, it will be padded (so the returned result has the size asked).
-
- Args:
- image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
- The image to resize.
- """
-
- if isinstance(image, Image.Image):
- image = self.to_numpy_array(image)
-
- image_height, image_width = image.shape[-2:]
-
- padded_center_crop_size = (
- (self.size / (self.crop_size)) * np.minimum(image_height, image_width).astype(np.float32)
- ).astype(np.int32)
-
- offset_height = ((image_height - padded_center_crop_size) + 1) // 2
- offset_width = ((image_width - padded_center_crop_size) + 1) // 2
- crop_window = [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
-
- image = image[
- :, crop_window[0] : crop_window[0] + crop_window[2], crop_window[1] : crop_window[1] + crop_window[3]
- ]
-
- return image
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (center cropping + resizing + normalization)
- if self.do_center_crop and self.crop_size is not None:
- images = [self.center_crop(image) for image in images]
- if self.do_resize and self.size is not None and self.resample is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+PerceiverFeatureExtractor = PerceiverImageProcessor
diff --git a/src/transformers/models/perceiver/image_processing_perceiver.py b/src/transformers/models/perceiver/image_processing_perceiver.py
new file mode 100644
index 0000000000..dbbfc3913a
--- /dev/null
+++ b/src/transformers/models/perceiver/image_processing_perceiver.py
@@ -0,0 +1,330 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Perceiver."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class PerceiverImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Perceiver image processor.
+
+ Args:
+ do_center_crop (`bool`, `optional`, defaults to `True`):
+ Whether or not to center crop the image. If the input size if smaller than `crop_size` along any edge, the
+ image will be padded with zeros and then center cropped. Can be overridden by the `do_center_crop`
+ parameter in the `preprocess` method.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Desired output size when applying center-cropping. Can be overridden by the `crop_size` parameter in the
+ `preprocess` method.
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image to `(size["height"], size["width"])`. Can be overridden by the `do_resize`
+ parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after resizing. Can be overridden by the `size` parameter in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
+ in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
+ in the `preprocess` method.
+ do_normalize:
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
+ crop_size = get_size_dict(crop_size)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ crop_size: Dict[str, int],
+ size: Optional[int] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] *
+ min_dim)`. Where `min_dim = min(size["height"], size["width"])`.
+
+ If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then
+ center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ crop_size (`Dict[str, int]`):
+ Desired output size after applying the center crop.
+ size (`Dict[str, int]`, *optional*):
+ Size of the image after resizing. If not provided, the self.size attribute will be used.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = self.size if size is None else size
+ size = get_size_dict(size)
+ crop_size = get_size_dict(crop_size)
+
+ height, width = get_image_size(image)
+ min_dim = min(height, width)
+ cropped_height = (size["height"] / crop_size["height"]) * min_dim
+ cropped_width = (size["width"] / crop_size["width"]) * min_dim
+ return center_crop(image, size=(cropped_height, cropped_width), data_format=data_format, **kwargs)
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PIL.Image.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BILINEAR`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `List[float]`):
+ Image mean.
+ std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[Dict[str, int]] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image to `crop_size`.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Desired output size after applying the center crop.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size = size if size is not None else self.size
+ size = get_size_dict(size)
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("If `do_center_crop` is set to `True`, `crop_size` must be provided.")
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and image standard deviation must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_center_crop:
+ images = [self.center_crop(image, crop_size, size=size) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, resample=resample) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/poolformer/feature_extraction_poolformer.py b/src/transformers/models/poolformer/feature_extraction_poolformer.py
index d8e632498f..72a5ec69f6 100644
--- a/src/transformers/models/poolformer/feature_extraction_poolformer.py
+++ b/src/transformers/models/poolformer/feature_extraction_poolformer.py
@@ -14,161 +14,11 @@
# limitations under the License.
"""Feature extractor class for PoolFormer."""
-import math
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_poolformer import PoolFormerImageProcessor
logger = logging.get_logger(__name__)
-class PoolFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a PoolFormer feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize_and_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to resize the shortest edge of the image and center crop the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 224):
- Center crop the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be center cropped to (size, size). Only has an effect if
- `do_resize_and_center_crop` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- crop_pct (`float`, *optional*, defaults to `0.9`):
- The percentage of the image to crop from the center. Only has an effect if `do_resize_and_center_crop` is
- set to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with `image_mean` and `image_std`.
- image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize_and_center_crop=True,
- size=224,
- resample=PILImageResampling.BICUBIC,
- crop_pct=0.9,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize_and_center_crop = do_resize_and_center_crop
- self.size = size
- self.resample = resample
- self.crop_pct = crop_pct
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + center cropping + normalization)
- if self.do_resize_and_center_crop and self.size is not None and self.crop_pct is not None:
- if isinstance(self.size, (tuple, list)):
- assert len(self.size) == 2
- if self.size[-1] == self.size[-2]:
- scale_size = int(math.floor(self.size[0] / self.crop_pct))
- else:
- scale_size = tuple([int(x / self.crop_pct) for x in self.size])
- else:
- scale_size = int(math.floor(self.size / self.crop_pct))
-
- # resize shortest edge of the image
- images = [
- self.resize(image=image, size=scale_size, resample=self.resample, default_to_square=False)
- for image in images
- ]
- # center crop
- images = [self.center_crop(image, size=self.size) for image in images]
-
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+PoolFormerFeatureExtractor = PoolFormerImageProcessor
diff --git a/src/transformers/models/poolformer/image_processing_poolformer.py b/src/transformers/models/poolformer/image_processing_poolformer.py
new file mode 100644
index 0000000000..e8fd6db240
--- /dev/null
+++ b/src/transformers/models/poolformer/image_processing_poolformer.py
@@ -0,0 +1,382 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for PoolFormer."""
+
+import math
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class PoolFormerImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a PoolFormer image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+ `do_resize` in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 256}`):
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. If crop_pct is
+ unset:
+ - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+ - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
+ aspect ratio.
+
+ If crop_pct is set:
+ - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
+ int(floor(w/crop_pct)))`
+ - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+ - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+ crop_pct (`float`, *optional*, defaults to `0.9`):
+ Percentage of the image to crop from the center. Can be overridden by `crop_pct` in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+ is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in the `preprocess`
+ method.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+ Size of the image after applying center crop. Only has an effect if `do_center_crop` is set to `True`. Can
+ be overridden by the `crop_size` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+ `preprocess` method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ crop_pct: int = 0.9,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_rescale: bool = True,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 256}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
+ crop_size = get_size_dict(crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.crop_pct = crop_pct
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ crop_pct: Optional[float] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ If crop_pct is unset:
+ - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`.
+ - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the
+ aspect ratio.
+
+ if crop_pct is set:
+ - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)),
+ int(floor(w/crop_pct)))`
+ - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+ - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)`
+ whilst maintaining the aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ crop_pct (`float`, *optional*):
+ Percentage of the image that will be cropped from the center. If set, the image is resized
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size and ("height" not in size or "width" not in size):
+ raise ValueError(f"size must contain 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
+ if crop_pct is not None:
+ if "shortest_edge" in size:
+ scale_size = int(math.floor(size["shortest_edge"] / crop_pct))
+ elif "height" in size and "width" in size:
+ if size["height"] == size["width"]:
+ scale_size = int(math.floor(size["height"] / crop_pct))
+ else:
+ scale_size = (
+ int(math.floor(size["height"] / crop_pct)),
+ int(math.floor(size["width"] / crop_pct)),
+ )
+ else:
+ raise ValueError("Invalid size for resize: {}".format(size))
+
+ output_size = get_resize_output_image_size(image, size=scale_size, default_to_square=False)
+ else:
+ if "shortest_edge" in size:
+ output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
+ elif "height" in size and "width" in size:
+ output_size = (size["height"], size["width"])
+ else:
+ raise ValueError("Invalid size for resize: {}".format(size))
+
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to (size["height"], size["width"]). If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ crop_pct: int = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after applying resize.
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+ Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after applying center crop.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_center_crop and crop_pct is None:
+ raise ValueError("Crop_pct must be specified if do_center_crop is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]
+
+ if do_center_crop:
+ images = [self.center_crop(image=image, size=crop_size) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/segformer/feature_extraction_segformer.py b/src/transformers/models/segformer/feature_extraction_segformer.py
index 980a33e46e..e43cc8cb58 100644
--- a/src/transformers/models/segformer/feature_extraction_segformer.py
+++ b/src/transformers/models/segformer/feature_extraction_segformer.py
@@ -14,248 +14,11 @@
# limitations under the License.
"""Feature extractor class for SegFormer."""
-from typing import List, Optional, Tuple, Union
+from ...utils import logging
+from .image_processing_segformer import SegformerImageProcessor
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
logger = logging.get_logger(__name__)
-class SegformerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a SegFormer feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input based on a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 512):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
- image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
- ImageNet std.
- reduce_labels (`bool`, *optional*, defaults to `False`):
- Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
- used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
- background label will be replaced by 255.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=512,
- resample=PILImageResampling.BILINEAR,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- reduce_labels=False,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
- self.reduce_labels = reduce_labels
-
- def __call__(
- self,
- images: ImageInput,
- segmentation_maps: ImageInput = None,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s) and optional corresponding segmentation maps.
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
- the number of channels, H and W are image height and width.
-
- segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
- Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- - **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided)
- """
- # Input type checking for clearer error
- valid_images = False
- valid_segmentation_maps = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- # Check that segmentation maps has a valid type
- if segmentation_maps is not None:
- if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
- valid_segmentation_maps = True
- elif isinstance(segmentation_maps, (list, tuple)):
- if (
- len(segmentation_maps) == 0
- or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
- or is_torch_tensor(segmentation_maps[0])
- ):
- valid_segmentation_maps = True
-
- if not valid_segmentation_maps:
- raise ValueError(
- "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single"
- " example),`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of"
- " examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
- if segmentation_maps is not None:
- segmentation_maps = [segmentation_maps]
-
- # reduce zero label if needed
- if self.reduce_labels:
- if segmentation_maps is not None:
- for idx, map in enumerate(segmentation_maps):
- if not isinstance(map, np.ndarray):
- map = np.array(map)
- # avoid using underflow conversion
- map[map == 0] = 255
- map = map - 1
- map[map == 254] = 255
- segmentation_maps[idx] = Image.fromarray(map.astype(np.uint8))
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if segmentation_maps is not None:
- segmentation_maps = [
- self.resize(map, size=self.size, resample=Image.NEAREST) for map in segmentation_maps
- ]
-
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
-
- if segmentation_maps is not None:
- labels = []
- for map in segmentation_maps:
- if not isinstance(map, np.ndarray):
- map = np.array(map)
- labels.append(map.astype(np.int64))
- # cast to np.int64
- data["labels"] = labels
-
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
-
- def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
- """
- Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
- PyTorch.
-
- Args:
- outputs ([`SegformerForSemanticSegmentation`]):
- Raw outputs of the model.
- target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
- List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
- None, predictions will not be resized.
- Returns:
- semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
- segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
- specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
- """
- logits = outputs.logits
-
- # Resize logits and compute semantic segmentation maps
- if target_sizes is not None:
- if len(logits) != len(target_sizes):
- raise ValueError(
- "Make sure that you pass in as many target sizes as the batch dimension of the logits"
- )
-
- if is_torch_tensor(target_sizes):
- target_sizes = target_sizes.numpy()
-
- semantic_segmentation = []
-
- for idx in range(len(logits)):
- resized_logits = torch.nn.functional.interpolate(
- logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
- )
- semantic_map = resized_logits[0].argmax(dim=0)
- semantic_segmentation.append(semantic_map)
- else:
- semantic_segmentation = logits.argmax(dim=1)
- semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
- return semantic_segmentation
+SegformerFeatureExtractor = SegformerImageProcessor
diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py
new file mode 100644
index 0000000000..f67b669092
--- /dev/null
+++ b/src/transformers/models/segformer/image_processing_segformer.py
@@ -0,0 +1,488 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Segformer."""
+
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_torch_available, is_torch_tensor, is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL.Image
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+class SegformerImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Segformer image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+ `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_reduce_labels: bool = False,
+ **kwargs
+ ) -> None:
+ if "reduce_labels" in kwargs:
+ warnings.warn(
+ "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
+ "`do_reduce_labels` instead.",
+ FutureWarning,
+ )
+ do_reduce_labels = kwargs.pop("reduce_labels")
+
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 512, "width": 512}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_reduce_labels = do_reduce_labels
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def reduce_label(self, label: ImageInput) -> np.ndarray:
+ label = to_numpy_array(label)
+ # Avoid using underflow conversion
+ label[label == 0] = 255
+ label = label - 1
+ label[label == 254] = 255
+ return label
+
+ def _preprocess(
+ self,
+ image: ImageInput,
+ do_reduce_labels: bool,
+ do_resize: bool,
+ do_rescale: bool,
+ do_normalize: bool,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ rescale_factor: Optional[float] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ ):
+ if do_reduce_labels:
+ image = self.reduce_label(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std)
+
+ return image
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+ image = self._preprocess(
+ image=image,
+ do_reduce_labels=False,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+ if data_format is not None:
+ image = to_channel_dimension_format(image, data_format)
+ return image
+
+ def _preprocess_mask(
+ self,
+ segmentation_map: ImageInput,
+ do_reduce_labels: bool = None,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ ) -> np.ndarray:
+ """Preprocesses a single mask."""
+ segmentation_map = to_numpy_array(segmentation_map)
+ # Add channel dimension if missing - needed for certain transformations
+ added_channel_dim = False
+ if segmentation_map.ndim == 2:
+ added_channel_dim = True
+ segmentation_map = segmentation_map[None, ...]
+ # reduce zero label if needed
+ segmentation_map = self._preprocess(
+ image=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ resample=PIL.Image.NEAREST,
+ size=size,
+ do_rescale=False,
+ do_normalize=False,
+ )
+ # Remove extra channel dimension if added for processing
+ if added_channel_dim:
+ segmentation_map = segmentation_map.squeeze(0)
+ segmentation_map = segmentation_map.astype(np.int64)
+ return segmentation_map
+
+ def __call__(self, images, segmentation_maps=None, **kwargs):
+ """
+ Preprocesses a batch of images and optionally segmentation maps.
+
+ Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
+ passed in as positional arguments.
+ """
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ segmentation_maps: Optional[ImageInput] = None,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ resample: Optional[PILImageResampling] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_reduce_labels: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ segmentation_maps (`ImageInput`, *optional*):
+ Segmentation map to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after `resize` is applied.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
+ ADE20k). The background label will be replaced by 255.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
+ resample = resample if resample is not None else self.resample
+ size = size if size is not None else self.size
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ if not is_batched(images):
+ images = [images]
+ segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if segmentation_maps is not None and not valid_images(segmentation_maps):
+ raise ValueError(
+ "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ images = [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ resample=resample,
+ size=size,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ )
+ for img in images
+ ]
+
+ data = {"pixel_values": images}
+
+ if segmentation_maps is not None:
+ segmentation_maps = [
+ self._preprocess_mask(
+ segmentation_map=segmentation_map,
+ do_reduce_labels=do_reduce_labels,
+ do_resize=do_resize,
+ resample=PIL.Image.NEAREST,
+ size=size,
+ )
+ for segmentation_map in segmentation_maps
+ ]
+ data["labels"] = segmentation_maps
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
+ """
+ Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports
+ PyTorch.
+
+ Args:
+ outputs ([`SegformerForSemanticSegmentation`]):
+ Raw outputs of the model.
+ target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
+ None, predictions will not be resized.
+ Returns:
+ semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+ specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
+ """
+ # TODO: add support for other frameworks
+ logits = outputs.logits
+
+ # Resize logits and compute semantic segmentation maps
+ if target_sizes is not None:
+ if len(logits) != len(target_sizes):
+ raise ValueError(
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+ )
+
+ if is_torch_tensor(target_sizes):
+ target_sizes = target_sizes.numpy()
+
+ semantic_segmentation = []
+
+ for idx in range(len(logits)):
+ resized_logits = torch.nn.functional.interpolate(
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+ )
+ semantic_map = resized_logits[0].argmax(dim=0)
+ semantic_segmentation.append(semantic_map)
+ else:
+ semantic_segmentation = logits.argmax(dim=1)
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+ return semantic_segmentation
diff --git a/src/transformers/models/videomae/feature_extraction_videomae.py b/src/transformers/models/videomae/feature_extraction_videomae.py
index 86e49f9b2e..f5bbdd388f 100644
--- a/src/transformers/models/videomae/feature_extraction_videomae.py
+++ b/src/transformers/models/videomae/feature_extraction_videomae.py
@@ -14,159 +14,11 @@
# limitations under the License.
"""Feature extractor class for VideoMAE."""
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
-from ...utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, TensorType, logging
+from ...utils import logging
+from .image_processing_videomae import VideoMAEImageProcessor
logger = logging.get_logger(__name__)
-class VideoMAEFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a VideoMAE feature extractor. This feature extractor can be used to prepare videos for the model.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the shorter edge of the input to a certain `size`.
- size (`int`, *optional*, defaults to 224):
- Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_center_crop (`bool`, *optional*, defaults to `True`):
- Whether to center crop the input to a certain `size`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BILINEAR,
- do_center_crop=True,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_center_crop = do_center_crop
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
-
- def resize_video(self, video, size, resample="bilinear"):
- return [self.resize(frame, size, resample, default_to_square=False) for frame in video]
-
- def crop_video(self, video, size):
- return [self.center_crop(frame, size) for frame in video]
-
- def normalize_video(self, video, mean, std):
- # video can be a list of PIL images, list of NumPy arrays or list of PyTorch tensors
- # first: convert to list of NumPy arrays
- video = [self.to_numpy_array(frame) for frame in video]
-
- # second: stack to get (num_frames, num_channels, height, width)
- video = np.stack(video, axis=0)
-
- # third: normalize
- if not isinstance(mean, np.ndarray):
- mean = np.array(mean).astype(video.dtype)
- if not isinstance(std, np.ndarray):
- std = np.array(std).astype(video.dtype)
-
- return (video - mean[None, :, None, None]) / std[None, :, None, None]
-
- def __call__(
- self, videos: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several video(s).
-
-
-
- NumPy arrays are converted to PIL images when resizing, so the most efficient is to pass PIL images.
-
-
-
- Args:
- videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarrray]]`,:
- `List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
- of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
- each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
- channels.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, num_frames,
- height, width).
- """
- # Input type checking for clearer error
- valid_videos = False
- is_batched = False
-
- # Check that videos have a valid type
- if isinstance(videos, (list, tuple)):
- if isinstance(videos[0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0]):
- valid_videos = True
- elif isinstance(videos[0], (list, tuple)) and (
- isinstance(videos[0][0], (Image.Image, np.ndarray)) or is_torch_tensor(videos[0][0])
- ):
- valid_videos = True
- is_batched = True
-
- if not valid_videos:
- raise ValueError(
- "Videos must of type `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]` (single"
- " example), `List[List[PIL.Image.Image]]`, `List[List[np.ndarray]]`, `List[List[torch.Tensor]]` (batch"
- " of examples)."
- )
-
- if not is_batched:
- videos = [videos]
-
- # transformations (resizing + center cropping + normalization)
- if self.do_resize and self.size is not None:
- videos = [self.resize_video(video, size=self.size, resample=self.resample) for video in videos]
- if self.do_center_crop and self.size is not None:
- videos = [self.crop_video(video, size=self.size) for video in videos]
- if self.do_normalize:
- videos = [self.normalize_video(video, mean=self.image_mean, std=self.image_std) for video in videos]
-
- # return as BatchFeature
- data = {"pixel_values": videos}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+VideoMAEFeatureExtractor = VideoMAEImageProcessor
diff --git a/src/transformers/models/videomae/image_processing_videomae.py b/src/transformers/models/videomae/image_processing_videomae.py
new file mode 100644
index 0000000000..25fc2bb88a
--- /dev/null
+++ b/src/transformers/models/videomae/image_processing_videomae.py
@@ -0,0 +1,380 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for VideoMAE."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_valid_image,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def make_batched(videos) -> List[List[ImageInput]]:
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+ return videos
+
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+ return [videos]
+
+ elif is_valid_image(videos):
+ return [[videos]]
+
+ raise ValueError(f"Could not make batched video from {videos}")
+
+
+class VideoMAEImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a VideoMAE image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the output image after resizing. The shortest edge of the image will be resized to
+ `size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overriden by
+ `size` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
+ parameter in the `preprocess` method.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
+ in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
+ have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
+ shortest edge of length `s` while keeping the aspect ratio of the original image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ if "shortest_edge" in size:
+ output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
+ elif "height" in size and "width" in size:
+ output_size = (size["height"], size["width"])
+ else:
+ raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `size` along any
+ edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size)
+ return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ image_mean (`float` or `List[float]`):
+ Image mean.
+ image_std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_center_crop and crop_size is None:
+ raise ValueError("Crop size must be specified if do_center_crop is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample)
+
+ if do_center_crop:
+ image = self.center_crop(image, size=crop_size)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std)
+
+ image = to_channel_dimension_format(image, data_format)
+ return image
+
+ def preprocess(
+ self,
+ videos: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: bool = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after applying resize.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_centre_crop`):
+ Whether to centre crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after applying the centre crop.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the inferred channel dimension format of the input image.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size)
+
+ if not valid_images(videos):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ videos = make_batched(videos)
+
+ videos = [
+ [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ )
+ for img in video
+ ]
+ for video in videos
+ ]
+
+ data = {"pixel_values": videos}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/src/transformers/models/vilt/feature_extraction_vilt.py b/src/transformers/models/vilt/feature_extraction_vilt.py
index c7b4b80e75..70fe34bfe5 100644
--- a/src/transformers/models/vilt/feature_extraction_vilt.py
+++ b/src/transformers/models/vilt/feature_extraction_vilt.py
@@ -14,282 +14,11 @@
# limitations under the License.
"""Feature extractor class for ViLT."""
-from typing import List, Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_STANDARD_MEAN,
- IMAGENET_STANDARD_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
+from ...utils import logging
+from .image_processing_vilt import ViltImageProcessor
logger = logging.get_logger(__name__)
-class ViltFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a ViLT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input based on `size`.
- size (`int`, *optional*, defaults to 384):
- Resize the shorter side of the input to the given size. Should be an integer. The longer side will be
- limited to under int((1333 / 800) * size) while preserving the aspect ratio. Only has an effect if
- `do_resize` is set to `True`.
- size_divisor (`int`, *optional*, defaults to 32):
- The size by which to make sure both the height and width can be divided.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values", "pixel_mask"]
-
- def __init__(
- self,
- do_resize=True,
- size=384,
- size_divisor=32,
- resample=PILImageResampling.BICUBIC,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.size_divisor = size_divisor
- self.resample = resample
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-
- def _resize(self, image, shorter=800, longer=1333, size_divisor=32, resample=PILImageResampling.BICUBIC):
- """
- Resizes the shorter edge of `image` to `shorter` and limits the longer edge to under `longer`, while preserving
- the aspect ratio. Also makes sure that both the height and width can be divided by `size_divisor`.
-
- Based on original implementation:
- https://github.com/dandelin/ViLT/blob/3db8b5035464afee84d951bf6322e1b27f1d072d/vilt/transforms/utils.py#L5
-
- Args:
- image (`PIL.Image`):
- The image to resize.
- shorter (`int`, *optional*, defaults to `800`):
- The size to which to resize the shorter side of the image.
- longer (`int`, *optional*, defaults to `1333`):
- The size by which to limit the longer side of the image, while preserving the aspect ratio.
- size_divisor (`int`, *optional*, defaults to `32`):
- The size by which both the height and the width must be divisible.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
- An optional resampling filter.
- """
- if not isinstance(image, Image.Image):
- image = self.to_pil_image(image)
-
- w, h = image.size
- min_size = shorter
- max_size = longer
- scale = min_size / min(w, h)
- if h < w:
- newh, neww = min_size, scale * w
- else:
- newh, neww = scale * h, min_size
-
- if max(newh, neww) > max_size:
- scale = max_size / max(newh, neww)
- newh = newh * scale
- neww = neww * scale
-
- newh, neww = int(newh + 0.5), int(neww + 0.5)
- newh, neww = newh // size_divisor * size_divisor, neww // size_divisor * size_divisor
-
- return self.resize(image, size=(neww, newh), resample=resample)
-
- def _max_by_axis(self, the_list):
- # type: (List[List[int]]) -> List[int]
- maxes = the_list[0]
- for sublist in the_list[1:]:
- for index, item in enumerate(sublist):
- maxes[index] = max(maxes[index], item)
- return maxes
-
- def pad_and_create_pixel_mask(
- self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None
- ):
- """
- Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
-
- Args:
- pixel_values_list (`List[torch.Tensor]`):
- List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
- objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model.
- - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
- *"pixel_mask"* is in `self.model_input_names`).
- """
-
- max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
- c, h, w = max_size
- padded_images = []
- pixel_mask = []
- for image in pixel_values_list:
- # create padded image
- padded_image = np.zeros((c, h, w), dtype=np.float32)
- padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
- padded_images.append(padded_image)
- # create pixel mask
- mask = np.zeros((h, w), dtype=np.int64)
- mask[: image.shape[1], : image.shape[2]] = True
- pixel_mask.append(mask)
-
- # return as BatchFeature
- data = {"pixel_values": padded_images, "pixel_mask": pixel_mask}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
-
- def __call__(
- self,
- images: ImageInput,
- pad_and_return_pixel_mask: Optional[bool] = True,
- return_tensors: Optional[Union[str, TensorType]] = None,
- **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
- Whether or not to pad images up to the largest image in a batch and create a pixel mask.
-
- If left to the default, will return a pixel mask that is:
-
- - 1 for pixels that are real (i.e. **not masked**),
- - 0 for pixels that are padding (i.e. **masked**).
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- - **pixel_mask** -- Pixel mask to be fed to a model (when `return_pixel_mask=True` or if *"pixel_mask"* is
- in `self.model_input_names`).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- longer = int((1333 / 800) * self.size)
- images = [
- self._resize(
- image=image,
- shorter=self.size,
- longer=longer,
- size_divisor=self.size_divisor,
- resample=self.resample,
- )
- for image in images
- ]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- if pad_and_return_pixel_mask:
- # pad images up to largest image in batch and create pixel_mask
- max_size = self._max_by_axis([list(image.shape) for image in images])
- c, h, w = max_size
- padded_images = []
- pixel_mask = []
- for image in images:
- # create padded image
- padded_image = np.zeros((c, h, w), dtype=np.float32)
- padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
- padded_images.append(padded_image)
- # create pixel mask
- mask = np.zeros((h, w), dtype=np.int64)
- mask[: image.shape[1], : image.shape[2]] = True
- pixel_mask.append(mask)
- images = padded_images
-
- # return as BatchFeature
- data = {}
- data["pixel_values"] = images
- if pad_and_return_pixel_mask:
- data["pixel_mask"] = pixel_mask
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+ViltFeatureExtractor = ViltImageProcessor
diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py
new file mode 100644
index 0000000000..c2ba7e68bb
--- /dev/null
+++ b/src/transformers/models/vilt/image_processing_vilt.py
@@ -0,0 +1,487 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Vilt."""
+
+import warnings
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+def max_across_indices(values: Iterable[Any]) -> List[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+def pad(
+ image: np.ndarray,
+ output_size: Tuple[int, int],
+ input_channel_dimension: Optional[ChannelDimension] = None,
+ data_format: Optional[ChannelDimension] = None,
+) -> np.ndarray:
+ """
+ Pad the bottom and right of the image with zeros to the output size.
+
+ Args:
+ image (`np.ndarray`):
+ Image to pad.
+ output_size (`Tuple[int, int]`):
+ Output size of the image.
+ input_channel_dimension (`ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be inferred from the input image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ if input_channel_dimension is None:
+ input_channel_dimension = infer_channel_dimension_format(image)
+
+ output_height, output_width = output_size
+ input_height, input_width = get_image_size(image)
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+
+ if input_channel_dimension == ChannelDimension.FIRST:
+ padded_image = np.pad(image, [(0, 0), (0, pad_bottom), (0, pad_right)], mode="constant", constant_values=0)
+ elif input_channel_dimension == ChannelDimension.LAST:
+ padded_image = np.pad(image, [(0, pad_bottom), (0, pad_right), (0, 0)], mode="constant", constant_values=0)
+ else:
+ raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
+
+ if data_format is not None:
+ padded_image = to_channel_dimension_format(padded_image, data_format)
+
+ return padded_image
+
+
+def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
+ """
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
+
+ Args:
+ image (`np.ndarray`):
+ Image to make the pixel mask for.
+ output_size (`Tuple[int, int]`):
+ Output size of the mask.
+ """
+ input_height, input_width = get_image_size(image)
+ mask = np.zeros(output_size, dtype=np.int64)
+ mask[:input_height, :input_width] = 1
+ return mask
+
+
+def get_max_dimensions(images: List[np.ndarray]) -> List[int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+ input_channel_dimension = infer_channel_dimension_format(images[0])
+
+ if input_channel_dimension == ChannelDimension.FIRST:
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
+ elif input_channel_dimension == ChannelDimension.LAST:
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
+ else:
+ raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
+ return (max_height, max_width)
+
+
+def get_resize_output_image_size(
+ input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
+) -> Tuple[int, int]:
+ input_height, input_width = get_image_size(input_image)
+ min_size, max_size = shorter, longer
+
+ scale = min_size / min(input_height, input_width)
+
+ if input_height < input_width:
+ new_height = min_size
+ new_width = scale * input_width
+ else:
+ new_height = scale * input_height
+ new_width = min_size
+
+ if max(new_height, new_width) > max_size:
+ scale = max_size / max(new_height, new_width)
+ new_height = scale * new_height
+ new_width = scale * new_width
+
+ new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
+ new_height = new_height // size_divisor * size_divisor
+ new_width = new_width // size_divisor * size_divisor
+
+ return new_height, new_width
+
+
+class ViltImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a ViLT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
+ Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
+ `int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
+ `do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
+ size_divisor (`int`, *optional*, defaults to 32):
+ The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
+ is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+ overridden by the `resample` parameter in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+ overridden by the `rescale_factor` parameter in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
+ overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
+ the `do_pad` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ size_divisor: int = 32,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: bool = True,
+ **kwargs
+ ) -> None:
+ if "pad_and_return_pixel_mask" in kwargs:
+ do_pad = kwargs.pop("pad_and_return_pixel_mask")
+
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 384}
+ size = get_size_dict(size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ self.size_divisor = size_divisor
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self.do_pad = do_pad
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ size_divisor: int = 32,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Resizes the shorter side of the image to `size["shortest_edge"]` while preserving the aspect ratio. If the
+ longer side is larger than the max size `(int(`size["shortest_edge"]` * 1333 / 800))`, the longer side is then
+ resized to the max size while preserving the aspect ratio.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Controls the size of the output image. Should be of the form `{"shortest_edge": int}`.
+ size_divisor (`int`, defaults to 32):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
+ shorter = size["shortest_edge"]
+ longer = int(1333 / 800 * shorter)
+ output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ):
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `List[float]`):
+ Image mean.
+ std (`float` or `List[float]`):
+ Image standard deviation.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def pad(
+ self,
+ images: List[np.ndarray],
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ ) -> BatchFeature:
+ """
+ Pads a batch of images with zeros to the size of largest height and width in the batch and optionally returns
+ their corresponding pixel mask.
+
+ Args:
+ images (`List[np.ndarray]`):
+ Batch of images to pad.
+ return_pixel_mask (`bool`, *optional*, defaults to `False`):
+ Whether to return the pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ pad_size = get_max_dimensions(images)
+ padded_images = [pad(image=image, output_size=pad_size, data_format=data_format) for image in images]
+ data = {"pixel_values": padded_images}
+ if return_pixel_mask:
+ masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
+ data["pixel_mask"] = masks
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def pad_and_create_pixel_mask(
+ self,
+ pixel_values_list: List[ImageInput],
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ ) -> BatchFeature:
+ """
+ Pads a batch of images with zeros to the size of largest height and width in the batch and returns their
+ corresponding pixel mask.
+
+ Args:
+ images (`List[np.ndarray]`):
+ Batch of images to pad.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ """
+ warnings.warn(
+ "This method is deprecated and will be removed in v4.26.0. Please use pad instead.", FutureWarning
+ )
+ # pad expects a list of np.ndarray, but the previous feature extractors expected torch tensors
+ images = [to_numpy_array(image) for image in pixel_values_list]
+ return self.pad(
+ images=images,
+ return_pixel_mask=True,
+ return_tensors=return_tensors,
+ data_format=data_format,
+ )
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Optional[Dict[str, int]] = None,
+ size_divisor: Optional[int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ do_pad: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Controls the size of the image after `resize`. The shortest edge of the image is resized to
+ `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
+ is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
+ edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
+ size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
+ The image is resized to a size that is a multiple of this value.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to normalize the image by if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+ Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also
+ created and returned.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ do_pad = do_pad if do_pad is not None else self.do_pad
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None or resample is None:
+ raise ValueError("Size and resample must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ if do_normalize and (image_mean is None or image_std is None):
+ raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [
+ self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
+ ]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ if do_pad:
+ encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
+ else:
+ encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+ return encoded_outputs
diff --git a/src/transformers/models/vit/feature_extraction_vit.py b/src/transformers/models/vit/feature_extraction_vit.py
index be1bd66a12..66f01eb7de 100644
--- a/src/transformers/models/vit/feature_extraction_vit.py
+++ b/src/transformers/models/vit/feature_extraction_vit.py
@@ -14,139 +14,12 @@
# limitations under the License.
"""Feature extractor class for ViT."""
-from typing import Optional, Union
-
-import numpy as np
-from PIL import Image
-
-from transformers.image_utils import PILImageResampling
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import (
- IMAGENET_STANDARD_MEAN,
- IMAGENET_STANDARD_STD,
- ImageFeatureExtractionMixin,
- ImageInput,
- is_torch_tensor,
-)
-from ...utils import TensorType, logging
+from ...utils import logging
+from .image_processing_vit import ViTImageProcessor
logger = logging.get_logger(__name__)
-class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
- r"""
- Constructs a ViT feature extractor.
-
- This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- Args:
- do_resize (`bool`, *optional*, defaults to `True`):
- Whether to resize the input to a certain `size`.
- size (`int` or `Tuple(int)`, *optional*, defaults to 224):
- Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
- integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
- set to `True`.
- resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
- An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
- `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
- `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
- to `True`.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the input with mean and standard deviation.
- image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of means for each channel, to be used when normalizing images.
- image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
- The sequence of standard deviations for each channel, to be used when normalizing images.
- """
-
- model_input_names = ["pixel_values"]
-
- def __init__(
- self,
- do_resize=True,
- size=224,
- resample=PILImageResampling.BILINEAR,
- do_normalize=True,
- image_mean=None,
- image_std=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.do_resize = do_resize
- self.size = size
- self.resample = resample
- self.do_normalize = do_normalize
- self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
- self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-
- def __call__(
- self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several image(s).
-
-
-
- NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
- PIL images.
-
-
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
- width).
- """
- # Input type checking for clearer error
- valid_images = False
-
- # Check that images has a valid type
- if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
- valid_images = True
- elif isinstance(images, (list, tuple)):
- if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
- valid_images = True
-
- if not valid_images:
- raise ValueError(
- "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
- "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
- )
-
- is_batched = bool(
- isinstance(images, (list, tuple))
- and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
- )
-
- if not is_batched:
- images = [images]
-
- # transformations (resizing + normalization)
- if self.do_resize and self.size is not None:
- images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
- if self.do_normalize:
- images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
-
- # return as BatchFeature
- data = {"pixel_values": images}
- encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
-
- return encoded_inputs
+# Feature extractor for ViT is being replaced by image processor
+ViTFeatureExtractor = ViTImageProcessor
diff --git a/src/transformers/models/vit/image_processing_vit.py b/src/transformers/models/vit/image_processing_vit.py
new file mode 100644
index 0000000000..5dcb04ef8b
--- /dev/null
+++ b/src/transformers/models/vit/image_processing_vit.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ViT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ is_batched,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ViTImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a ViT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize:
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.size = size
+ self.resample = resample
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample:
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+ return resize(
+ image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
+ )
+
+ def rescale(
+ self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`float`):
+ The scaling factor to rescale pixel values by.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, List[float]],
+ std: Union[float, List[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `List[float]`):
+ Image mean to use for normalization.
+ std (`float` or `List[float]`):
+ Image standard deviation to use for normalization.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The normalized image.
+ """
+ return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size_dict = get_size_dict(size)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_resize:
+ images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
+
+ if do_rescale:
+ images = [self.rescale(image=image, scale=rescale_factor) for image in images]
+
+ if do_normalize:
+ images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
+
+ images = [to_channel_dimension_format(image, data_format) for image in images]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/tests/models/beit/test_feature_extraction_beit.py b/tests/models/beit/test_feature_extraction_beit.py
index a9338aea1f..de9e552393 100644
--- a/tests/models/beit/test_feature_extraction_beit.py
+++ b/tests/models/beit/test_feature_extraction_beit.py
@@ -44,14 +44,16 @@ class BeitFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=20,
+ size=None,
do_center_crop=True,
- crop_size=18,
+ crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
- reduce_labels=False,
+ do_reduce_labels=False,
):
+ size = size if size is not None else {"height": 20, "width": 20}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -65,7 +67,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
- self.reduce_labels = reduce_labels
+ self.do_reduce_labels = do_reduce_labels
def prepare_feat_extract_dict(self):
return {
@@ -76,7 +78,7 @@ class BeitFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
- "reduce_labels": self.reduce_labels,
+ "do_reduce_labels": self.do_reduce_labels,
}
@@ -141,8 +143,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -153,8 +155,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -173,8 +175,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -185,8 +187,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -205,8 +207,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -217,8 +219,8 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -239,16 +241,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -262,16 +264,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.feature_extract_tester.batch_size,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -287,16 +289,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -312,16 +314,16 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
2,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
diff --git a/tests/models/clip/test_feature_extraction_clip.py b/tests/models/clip/test_feature_extraction_clip.py
index 8f36a65ae2..e9c169cf51 100644
--- a/tests/models/clip/test_feature_extraction_clip.py
+++ b/tests/models/clip/test_feature_extraction_clip.py
@@ -43,14 +43,16 @@ class CLIPFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=20,
+ size=None,
do_center_crop=True,
- crop_size=18,
+ crop_size=None,
do_normalize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
+ size = size if size is not None else {"shortest_edge": 20}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -151,8 +153,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -163,8 +165,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -183,8 +185,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -195,8 +197,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -215,8 +217,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -227,8 +229,8 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -276,8 +278,8 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un
(
1,
self.expected_encoded_image_num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -288,7 +290,7 @@ class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, un
(
self.feature_extract_tester.batch_size,
self.expected_encoded_image_num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
diff --git a/tests/models/convnext/test_feature_extraction_convnext.py b/tests/models/convnext/test_feature_extraction_convnext.py
index f02341972b..1419280f97 100644
--- a/tests/models/convnext/test_feature_extraction_convnext.py
+++ b/tests/models/convnext/test_feature_extraction_convnext.py
@@ -43,12 +43,13 @@ class ConvNextFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=20,
+ size=None,
crop_pct=0.875,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"shortest_edge": 20}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -113,8 +114,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["shortest_edge"],
+ self.feature_extract_tester.size["shortest_edge"],
),
)
@@ -125,8 +126,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["shortest_edge"],
+ self.feature_extract_tester.size["shortest_edge"],
),
)
@@ -145,8 +146,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["shortest_edge"],
+ self.feature_extract_tester.size["shortest_edge"],
),
)
@@ -157,8 +158,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["shortest_edge"],
+ self.feature_extract_tester.size["shortest_edge"],
),
)
@@ -177,8 +178,8 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["shortest_edge"],
+ self.feature_extract_tester.size["shortest_edge"],
),
)
@@ -189,7 +190,7 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["shortest_edge"],
+ self.feature_extract_tester.size["shortest_edge"],
),
)
diff --git a/tests/models/deit/test_feature_extraction_deit.py b/tests/models/deit/test_feature_extraction_deit.py
index 92a477f182..03b869a967 100644
--- a/tests/models/deit/test_feature_extraction_deit.py
+++ b/tests/models/deit/test_feature_extraction_deit.py
@@ -43,13 +43,16 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=20,
+ size=None,
do_center_crop=True,
- crop_size=18,
+ crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"height": 20, "width": 20}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
+
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -117,8 +120,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -129,8 +132,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -149,8 +152,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -161,8 +164,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -181,8 +184,8 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -193,7 +196,7 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
diff --git a/tests/models/dpt/test_feature_extraction_dpt.py b/tests/models/dpt/test_feature_extraction_dpt.py
index a0cf1cba23..bcfec4b2aa 100644
--- a/tests/models/dpt/test_feature_extraction_dpt.py
+++ b/tests/models/dpt/test_feature_extraction_dpt.py
@@ -43,11 +43,12 @@ class DPTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -106,8 +107,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -118,8 +119,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -138,8 +139,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -150,8 +151,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -170,8 +171,8 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -182,7 +183,7 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
diff --git a/tests/models/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py
index 11cfd936d0..fe0e6dca26 100644
--- a/tests/models/flava/test_feature_extraction_flava.py
+++ b/tests/models/flava/test_feature_extraction_flava.py
@@ -28,11 +28,10 @@ if is_torch_available():
import torch
if is_vision_available():
- from PIL import Image
+ import PIL
from transformers import FlavaFeatureExtractor
- from transformers.image_utils import PILImageResampling
- from transformers.models.flava.feature_extraction_flava import (
+ from transformers.models.flava.image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
FLAVA_CODEBOOK_STD,
FLAVA_IMAGE_MEAN,
@@ -51,10 +50,12 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=224,
+ size=None,
do_center_crop=True,
- crop_size=224,
+ crop_size=None,
resample=None,
+ do_rescale=True,
+ rescale_factor=1 / 255,
do_normalize=True,
image_mean=FLAVA_IMAGE_MEAN,
image_std=FLAVA_IMAGE_STD,
@@ -65,23 +66,30 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
mask_group_min_aspect_ratio=0.3,
mask_group_max_aspect_ratio=None,
codebook_do_resize=True,
- codebook_size=112,
+ codebook_size=None,
codebook_resample=None,
codebook_do_center_crop=True,
- codebook_crop_size=112,
+ codebook_crop_size=None,
codebook_do_map_pixels=True,
codebook_do_normalize=True,
codebook_image_mean=FLAVA_CODEBOOK_MEAN,
codebook_image_std=FLAVA_CODEBOOK_STD,
):
+ size = size if size is not None else {"height": 224, "width": 224}
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
+ codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
+
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.size = size
- self.resample = resample if resample is not None else PILImageResampling.BICUBIC
+ self.resample = resample if resample is not None else PIL.Image.Resampling.BICUBIC
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@@ -97,7 +105,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
self.codebook_do_resize = codebook_do_resize
self.codebook_size = codebook_size
- self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS
+ self.codebook_resample = codebook_resample if codebook_resample is not None else PIL.Image.Resampling.LANCZOS
self.codebook_do_center_crop = codebook_do_center_crop
self.codebook_crop_size = codebook_crop_size
self.codebook_do_map_pixels = codebook_do_map_pixels
@@ -113,6 +121,8 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
"do_resize": self.do_resize,
"size": self.size,
"resample": self.resample,
+ "do_rescale": self.do_rescale,
+ "rescale_factor": self.rescale_factor,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"input_size_patches": self.input_size_patches,
@@ -133,7 +143,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
}
def get_expected_image_size(self):
- return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+ return (self.size["height"], self.size["width"])
def get_expected_mask_size(self):
return (
@@ -143,10 +153,7 @@ class FlavaFeatureExtractionTester(unittest.TestCase):
)
def get_expected_codebook_image_size(self):
- if not isinstance(self.codebook_size, tuple):
- return (self.codebook_size, self.codebook_size)
- else:
- return self.codebook_size
+ return (self.codebook_size["height"], self.codebook_size["width"])
@require_torch
@@ -172,6 +179,8 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "resample"))
self.assertTrue(hasattr(feature_extractor, "crop_size"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "do_rescale"))
+ self.assertTrue(hasattr(feature_extractor, "rescale_factor"))
self.assertTrue(hasattr(feature_extractor, "masking_generator"))
self.assertTrue(hasattr(feature_extractor, "codebook_do_resize"))
self.assertTrue(hasattr(feature_extractor, "codebook_size"))
@@ -192,7 +201,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
- self.assertIsInstance(image, Image.Image)
+ self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
@@ -324,7 +333,7 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
- self.assertIsInstance(image, Image.Image)
+ self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
diff --git a/tests/models/flava/test_processor_flava.py b/tests/models/flava/test_processor_flava.py
index 21cc84d5f2..11b8a8add4 100644
--- a/tests/models/flava/test_processor_flava.py
+++ b/tests/models/flava/test_processor_flava.py
@@ -32,7 +32,7 @@ if is_vision_available():
from PIL import Image
from transformers import FlavaFeatureExtractor, FlavaProcessor
- from transformers.models.flava.feature_extraction_flava import (
+ from transformers.models.flava.image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
FLAVA_CODEBOOK_STD,
FLAVA_IMAGE_MEAN,
@@ -69,7 +69,6 @@ class FlavaProcessorTest(unittest.TestCase):
"mask_group_max_aspect_ratio": None,
"codebook_do_resize": True,
"codebook_size": 112,
- "codebook_resample": None,
"codebook_do_center_crop": True,
"codebook_crop_size": 112,
"codebook_do_map_pixels": True,
diff --git a/tests/models/imagegpt/test_feature_extraction_imagegpt.py b/tests/models/imagegpt/test_feature_extraction_imagegpt.py
index 1dd3786759..0dd614840b 100644
--- a/tests/models/imagegpt/test_feature_extraction_imagegpt.py
+++ b/tests/models/imagegpt/test_feature_extraction_imagegpt.py
@@ -47,9 +47,10 @@ class ImageGPTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
do_normalize=True,
):
+ size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
diff --git a/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py b/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
index 59c30d779c..0a3528e16c 100644
--- a/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py
@@ -43,9 +43,10 @@ class LayoutLMv2FeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
apply_ocr=True,
):
+ size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -97,8 +98,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -112,8 +113,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -132,8 +133,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -144,8 +145,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -164,8 +165,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -176,8 +177,8 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -210,12 +211,4 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
encoding = feature_extractor(image, return_tensors="pt")
- self.assertEqual(
- encoding.pixel_values.shape,
- (
- 1,
- 3,
- 224,
- 224,
- ),
- )
+ self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
diff --git a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
index 9d05a4b665..68a32e6e8f 100644
--- a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
@@ -43,9 +43,10 @@ class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
apply_ocr=True,
):
+ size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -97,8 +98,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -112,8 +113,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -132,8 +133,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -144,8 +145,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -164,8 +165,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -176,8 +177,8 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
diff --git a/tests/models/levit/test_feature_extraction_levit.py b/tests/models/levit/test_feature_extraction_levit.py
index 98a704b97a..138542d85d 100644
--- a/tests/models/levit/test_feature_extraction_levit.py
+++ b/tests/models/levit/test_feature_extraction_levit.py
@@ -43,12 +43,15 @@ class LevitFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
do_center_crop=True,
+ crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"shortest_edge": 18}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -58,6 +61,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@@ -70,6 +74,7 @@ class LevitFeatureExtractionTester(unittest.TestCase):
"do_resize": self.do_resize,
"do_center_crop": self.do_center_crop,
"size": self.size,
+ "crop_size": self.crop_size,
}
@@ -113,8 +118,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -125,8 +130,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -145,8 +150,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -157,8 +162,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -177,8 +182,8 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -189,7 +194,7 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
diff --git a/tests/models/mobilevit/test_feature_extraction_mobilevit.py b/tests/models/mobilevit/test_feature_extraction_mobilevit.py
index f13267c541..1a2f52d0da 100644
--- a/tests/models/mobilevit/test_feature_extraction_mobilevit.py
+++ b/tests/models/mobilevit/test_feature_extraction_mobilevit.py
@@ -43,11 +43,13 @@ class MobileViTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=20,
+ size=None,
do_center_crop=True,
- crop_size=18,
+ crop_size=None,
do_flip_channel_order=True,
):
+ size = size if size is not None else {"shortest_edge": 20}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -109,8 +111,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -121,8 +123,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -141,8 +143,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -153,8 +155,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -173,8 +175,8 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -185,7 +187,7 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.crop_size,
- self.feature_extract_tester.crop_size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
diff --git a/tests/models/poolformer/test_feature_extraction_poolformer.py b/tests/models/poolformer/test_feature_extraction_poolformer.py
index bb65835d5d..41599989b1 100644
--- a/tests/models/poolformer/test_feature_extraction_poolformer.py
+++ b/tests/models/poolformer/test_feature_extraction_poolformer.py
@@ -41,12 +41,15 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize_and_center_crop=True,
- size=30,
+ size=None,
crop_pct=0.9,
+ crop_size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"shortest_edge": 30}
+ crop_size = crop_size if crop_size is not None else {"height": 30, "width": 30}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -55,6 +58,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
self.do_resize_and_center_crop = do_resize_and_center_crop
self.size = size
self.crop_pct = crop_pct
+ self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@@ -64,6 +68,7 @@ class PoolFormerFeatureExtractionTester(unittest.TestCase):
"size": self.size,
"do_resize_and_center_crop": self.do_resize_and_center_crop,
"crop_pct": self.crop_pct,
+ "crop_size": self.crop_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
@@ -111,8 +116,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -123,8 +128,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -143,8 +148,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -155,8 +160,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -175,8 +180,8 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -187,7 +192,7 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
diff --git a/tests/models/segformer/test_feature_extraction_segformer.py b/tests/models/segformer/test_feature_extraction_segformer.py
index 75083012d8..b3ba44862b 100644
--- a/tests/models/segformer/test_feature_extraction_segformer.py
+++ b/tests/models/segformer/test_feature_extraction_segformer.py
@@ -43,12 +43,13 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=30,
+ size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
- reduce_labels=False,
+ do_reduce_labels=False,
):
+ size = size if size is not None else {"height": 30, "width": 30}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -59,7 +60,7 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
- self.reduce_labels = reduce_labels
+ self.do_reduce_labels = do_reduce_labels
def prepare_feat_extract_dict(self):
return {
@@ -68,7 +69,7 @@ class SegformerFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
- "reduce_labels": self.reduce_labels,
+ "do_reduce_labels": self.do_reduce_labels,
}
@@ -112,7 +113,7 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
- self.assertTrue(hasattr(feature_extractor, "reduce_labels"))
+ self.assertTrue(hasattr(feature_extractor, "do_reduce_labels"))
def test_batch_feature(self):
pass
@@ -132,8 +133,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -144,8 +145,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -164,8 +165,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -176,8 +177,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -196,8 +197,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -208,8 +209,8 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -230,16 +231,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -253,16 +254,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.feature_extract_tester.batch_size,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -278,16 +279,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
@@ -303,16 +304,16 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
(
2,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
diff --git a/tests/models/videomae/test_feature_extraction_videomae.py b/tests/models/videomae/test_feature_extraction_videomae.py
index cfe00f51e5..eebdbb7cc3 100644
--- a/tests/models/videomae/test_feature_extraction_videomae.py
+++ b/tests/models/videomae/test_feature_extraction_videomae.py
@@ -44,11 +44,15 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
+ crop_size=None,
):
+ size = size if size is not None else {"shortest_edge": 18}
+ crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
+
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -61,6 +65,7 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
+ self.crop_size = crop_size
def prepare_feat_extract_dict(self):
return {
@@ -69,6 +74,7 @@ class VideoMAEFeatureExtractionTester(unittest.TestCase):
"do_normalize": self.do_normalize,
"do_resize": self.do_resize,
"size": self.size,
+ "crop_size": self.crop_size,
}
@@ -91,6 +97,7 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "size"))
def test_batch_feature(self):
@@ -113,8 +120,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
1,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -126,8 +133,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -148,8 +155,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
1,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -161,8 +168,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -183,8 +190,8 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
1,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
@@ -196,7 +203,7 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_frames,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.crop_size["height"],
+ self.feature_extract_tester.crop_size["width"],
),
)
diff --git a/tests/models/vilt/test_feature_extraction_vilt.py b/tests/models/vilt/test_feature_extraction_vilt.py
index 62a9783c81..d2e0d2e803 100644
--- a/tests/models/vilt/test_feature_extraction_vilt.py
+++ b/tests/models/vilt/test_feature_extraction_vilt.py
@@ -43,12 +43,13 @@ class ViltFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=30,
+ size=None,
size_divisor=2,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"shortest_edge": 30}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -78,18 +79,19 @@ class ViltFeatureExtractionTester(unittest.TestCase):
assuming do_resize is set to True with a scalar size and size_divisor.
"""
if not batched:
+ size = self.size["shortest_edge"]
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
- scale = self.size / min(w, h)
+ scale = size / min(w, h)
if h < w:
- newh, neww = self.size, scale * w
+ newh, neww = size, scale * w
else:
- newh, neww = scale * h, self.size
+ newh, neww = scale * h, size
- max_size = int((1333 / 800) * self.size)
+ max_size = int((1333 / 800) * size)
if max(newh, neww) > max_size:
scale = max_size / max(newh, neww)
newh = newh * scale
@@ -233,7 +235,7 @@ class ViltFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize feature_extractors
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
- feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+ feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False, do_rescale=False)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
diff --git a/tests/models/vit/test_feature_extraction_vit.py b/tests/models/vit/test_feature_extraction_vit.py
index 2daf6452ff..e33b7361ab 100644
--- a/tests/models/vit/test_feature_extraction_vit.py
+++ b/tests/models/vit/test_feature_extraction_vit.py
@@ -43,11 +43,12 @@ class ViTFeatureExtractionTester(unittest.TestCase):
min_resolution=30,
max_resolution=400,
do_resize=True,
- size=18,
+ size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
+ size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -109,8 +110,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -121,8 +122,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -141,8 +142,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -153,8 +154,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -173,8 +174,8 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
1,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
@@ -185,7 +186,7 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
- self.feature_extract_tester.size,
- self.feature_extract_tester.size,
+ self.feature_extract_tester.size["height"],
+ self.feature_extract_tester.size["width"],
),
)
diff --git a/tests/utils/test_image_processing_utils.py b/tests/utils/test_image_processing_utils.py
new file mode 100644
index 0000000000..afb6283e6e
--- /dev/null
+++ b/tests/utils/test_image_processing_utils.py
@@ -0,0 +1,71 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.image_processing_utils import get_size_dict
+
+
+class ImageProcessingUtilsTester(unittest.TestCase):
+ def test_get_size_dict(self):
+ # Test a dict with the wrong keys raises an error
+ inputs = {"wrong_key": 224}
+ with self.assertRaises(ValueError):
+ get_size_dict(inputs)
+
+ inputs = {"height": 224}
+ with self.assertRaises(ValueError):
+ get_size_dict(inputs)
+
+ inputs = {"width": 224, "shortest_edge": 224}
+ with self.assertRaises(ValueError):
+ get_size_dict(inputs)
+
+ # Test a dict with the correct keys is returned as is
+ inputs = {"height": 224, "width": 224}
+ outputs = get_size_dict(inputs)
+ self.assertEqual(outputs, inputs)
+
+ inputs = {"shortest_edge": 224}
+ outputs = get_size_dict(inputs)
+ self.assertEqual(outputs, {"shortest_edge": 224})
+
+ inputs = {"longest_edge": 224, "shortest_edge": 224}
+ outputs = get_size_dict(inputs)
+ self.assertEqual(outputs, {"longest_edge": 224, "shortest_edge": 224})
+
+ # Test a single int value which represents (size, size)
+ outputs = get_size_dict(224)
+ self.assertEqual(outputs, {"height": 224, "width": 224})
+
+ # Test a single int value which represents the shortest edge
+ outputs = get_size_dict(224, default_to_square=False)
+ self.assertEqual(outputs, {"shortest_edge": 224})
+
+ # Test a tuple of ints which represents (height, width)
+ outputs = get_size_dict((150, 200))
+ self.assertEqual(outputs, {"height": 150, "width": 200})
+
+ # Test a tuple of ints which represents (width, height)
+ outputs = get_size_dict((150, 200), height_width_order=False)
+ self.assertEqual(outputs, {"height": 200, "width": 150})
+
+ # Test an int representing the shortest edge and max_size which represents the longest edge
+ outputs = get_size_dict(224, max_size=256, default_to_square=False)
+ self.assertEqual(outputs, {"shortest_edge": 224, "longest_edge": 256})
+
+ # Test int with default_to_square=True and max_size fails
+ with self.assertRaises(ValueError):
+ get_size_dict(224, max_size=256, default_to_square=True)