Add fast image processor SAM (#39385)
* add fast image processor sam * nits
This commit is contained in:
@@ -25,7 +25,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
SAM (Segment Anything Model) was proposed in [Segment Anything](https://huggingface.co/papers/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
|
||||
|
||||
The model can be used to predict segmentation masks of any object of interest given an input image.
|
||||
The model can be used to predict segmentation masks of any object of interest given an input image.
|
||||
|
||||

|
||||
|
||||
@@ -37,9 +37,9 @@ Tips:
|
||||
|
||||
- The model predicts binary masks that states the presence or not of the object of interest given an image.
|
||||
- The model predicts much better results if input 2D points and/or input bounding boxes are provided
|
||||
- You can prompt multiple points for the same image, and predict a single mask.
|
||||
- You can prompt multiple points for the same image, and predict a single mask.
|
||||
- Fine-tuning the model is not supported yet
|
||||
- According to the paper, textual input should be also supported. However, at this time of writing this seems not to be supported according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
|
||||
- According to the paper, textual input should be also supported. However, at this time of writing this seems not to be supported according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
|
||||
|
||||
|
||||
This model was contributed by [ybelkada](https://huggingface.co/ybelkada) and [ArthurZ](https://huggingface.co/ArthurZ).
|
||||
@@ -149,6 +149,11 @@ alt="drawing" width="900"/>
|
||||
[[autodoc]] SamImageProcessor
|
||||
|
||||
|
||||
## SamImageProcessorFast
|
||||
|
||||
[[autodoc]] SamImageProcessorFast
|
||||
|
||||
|
||||
## SamVisionModel
|
||||
|
||||
[[autodoc]] SamVisionModel
|
||||
|
||||
@@ -147,8 +147,8 @@ else:
|
||||
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
|
||||
("sam", ("SamImageProcessor",)),
|
||||
("sam_hq", ("SamImageProcessor",)),
|
||||
("sam", ("SamImageProcessor", "SamImageProcessorFast")),
|
||||
("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
|
||||
("segformer", ("SegformerImageProcessor",)),
|
||||
("seggpt", ("SegGptImageProcessor",)),
|
||||
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
|
||||
|
||||
@@ -656,6 +656,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("vits", ("VitsTokenizer", None)),
|
||||
(
|
||||
"voxtral",
|
||||
(
|
||||
"MistralCommonTokenizer"
|
||||
if is_mistral_common_available()
|
||||
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
||||
),
|
||||
),
|
||||
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
|
||||
("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
|
||||
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
|
||||
|
||||
@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_sam import *
|
||||
from .image_processing_sam import *
|
||||
from .image_processing_sam_fast import *
|
||||
from .modeling_sam import *
|
||||
from .modeling_tf_sam import *
|
||||
from .processing_sam import *
|
||||
|
||||
@@ -387,6 +387,11 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
|
||||
return segmentation_map, original_size
|
||||
|
||||
def __call__(self, images, segmentation_maps=None, **kwargs):
|
||||
# Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both
|
||||
# be passed in as positional arguments.
|
||||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
def preprocess(
|
||||
self,
|
||||
|
||||
865
src/transformers/models/sam/image_processing_sam_fast.py
Normal file
865
src/transformers/models/sam/image_processing_sam_fast.py
Normal file
@@ -0,0 +1,865 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Image processor class for SAM."""
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...image_processing_utils import BatchFeature, get_size_dict
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
make_list_of_images,
|
||||
pil_torch_interpolation_mapping,
|
||||
validate_kwargs,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.nn import functional as F_t
|
||||
|
||||
if is_torchvision_available() and is_torchvision_v2_available():
|
||||
from torchvision.ops.boxes import batched_nms
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
elif is_torchvision_available():
|
||||
from torchvision.ops.boxes import batched_nms
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
r"""
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
|
||||
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
|
||||
pad_size (`dict[str, int]`, *optional*):
|
||||
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
|
||||
provided for preprocessing.
|
||||
mask_size (`dict[str, int]`, *optional*):
|
||||
The size `{"longest_edge": int}` to resize the segmentation maps to.
|
||||
mask_pad_size (`dict[str, int]`, *optional*):
|
||||
The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation
|
||||
map size provided for preprocessing.
|
||||
"""
|
||||
|
||||
mask_size: Optional[dict[str, int]]
|
||||
do_pad: Optional[bool]
|
||||
pad_size: Optional[dict[str, int]]
|
||||
mask_pad_size: Optional[dict[str, int]]
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class SamImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
image_mean = IMAGENET_DEFAULT_MEAN
|
||||
image_std = IMAGENET_DEFAULT_STD
|
||||
size = {"longest_edge": 1024}
|
||||
mask_size = {"longest_edge": 256}
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
|
||||
valid_kwargs = SamFastImageProcessorKwargs
|
||||
|
||||
do_pad = True
|
||||
pad_size = {"height": 1024, "width": 1024}
|
||||
mask_pad_size = {"height": 256, "width": 256}
|
||||
|
||||
def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def pad_image(self, images: "torch.Tensor", pad_size: SizeDict):
|
||||
"""Pad images to the specified size."""
|
||||
output_height, output_width = pad_size.height, pad_size.width
|
||||
input_height, input_width = images.shape[-2:]
|
||||
pad_width = output_width - input_width
|
||||
pad_height = output_height - input_height
|
||||
padding = (0, 0, pad_width, pad_height)
|
||||
return F.pad(images, padding)
|
||||
|
||||
def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
|
||||
"""
|
||||
Compute the output size given input size and target long side length.
|
||||
"""
|
||||
oldh, oldw = old_shape
|
||||
scale = longest_edge * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
newh = int(newh + 0.5)
|
||||
neww = int(neww + 0.5)
|
||||
return (newh, neww)
|
||||
|
||||
def resize(
|
||||
self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F.InterpolationMode"], **kwargs
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`dict[str, int]`):
|
||||
Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
|
||||
edge of the image will be resized to the specified size, while the other edge will be resized to
|
||||
maintain the aspect ratio.
|
||||
interpolation:
|
||||
`F.InterpolationMode` filter to use when resizing the image e.g. `F.InterpolationMode.BICUBIC`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The resized image.
|
||||
"""
|
||||
if not size.longest_edge:
|
||||
raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
|
||||
input_size = image.shape[-2:]
|
||||
output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge)
|
||||
return super().resize(
|
||||
image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs
|
||||
)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, list[float]]],
|
||||
image_std: Optional[Union[float, list[float]]],
|
||||
do_pad: bool,
|
||||
pad_size: SizeDict,
|
||||
disable_grouping: Optional[bool],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
|
||||
resized_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
# Fused rescale and normalize
|
||||
stacked_images = self.rescale_and_normalize(
|
||||
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
if do_pad:
|
||||
stacked_images = self.pad_image(stacked_images, pad_size)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
|
||||
return processed_images
|
||||
|
||||
def _preprocess_segmentation_maps(
|
||||
self,
|
||||
segmentation_maps,
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses segmentation maps."""
|
||||
processed_segmentation_maps = []
|
||||
for segmentation_map in segmentation_maps:
|
||||
segmentation_map = self._process_image(
|
||||
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
|
||||
)
|
||||
|
||||
if segmentation_map.ndim == 2:
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
processed_segmentation_maps.append(segmentation_map)
|
||||
|
||||
kwargs["do_rescale"] = False
|
||||
kwargs["do_normalize"] = False
|
||||
kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
|
||||
kwargs["size"] = kwargs.pop("mask_size")
|
||||
kwargs["pad_size"] = kwargs.pop("mask_pad_size")
|
||||
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
|
||||
|
||||
processed_segmentation_maps = processed_segmentation_maps.squeeze(1) # Remove channel dimension
|
||||
|
||||
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
|
||||
return processed_segmentation_maps
|
||||
|
||||
def _further_process_kwargs(
|
||||
self,
|
||||
size: Optional[SizeDict] = None,
|
||||
pad_size: Optional[SizeDict] = None,
|
||||
mask_size: Optional[SizeDict] = None,
|
||||
mask_pad_size: Optional[SizeDict] = None,
|
||||
default_to_square: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Update kwargs that need further processing before being validated
|
||||
Can be overridden by subclasses to customize the processing of kwargs.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if size is not None:
|
||||
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
|
||||
if pad_size is not None:
|
||||
pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size"))
|
||||
if mask_size is not None:
|
||||
mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
|
||||
if mask_pad_size is not None:
|
||||
mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size"))
|
||||
if isinstance(image_mean, list):
|
||||
image_mean = tuple(image_mean)
|
||||
if isinstance(image_std, list):
|
||||
image_std = tuple(image_std)
|
||||
if data_format is None:
|
||||
data_format = ChannelDimension.FIRST
|
||||
|
||||
kwargs["size"] = size
|
||||
kwargs["pad_size"] = pad_size
|
||||
kwargs["mask_size"] = mask_size
|
||||
kwargs["mask_pad_size"] = mask_pad_size
|
||||
kwargs["default_to_square"] = default_to_square
|
||||
kwargs["image_mean"] = image_mean
|
||||
kwargs["image_std"] = image_std
|
||||
kwargs["data_format"] = data_format
|
||||
|
||||
return kwargs
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
segmentation_maps: Optional[ImageInput] = None,
|
||||
**kwargs: Unpack[SamFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
r"""
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
The segmentation maps to preprocess.
|
||||
"""
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
images = self._prepare_input_images(
|
||||
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
||||
)
|
||||
|
||||
# Prepare segmentation maps
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
resample = kwargs.pop("resample")
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||
)
|
||||
|
||||
# Pop kwargs that are not needed in _preprocess
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
original_sizes = [image.shape[-2:] for image in images]
|
||||
|
||||
images = self._preprocess(
|
||||
images=images,
|
||||
**kwargs,
|
||||
)
|
||||
reshaped_input_sizes = [image.shape[-2:] for image in images]
|
||||
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = self._preprocess_segmentation_maps(
|
||||
segmentation_maps=segmentation_maps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"pixel_values": images,
|
||||
"labels": segmentation_maps,
|
||||
"original_sizes": original_sizes,
|
||||
"reshaped_input_sizes": reshaped_input_sizes,
|
||||
},
|
||||
tensor_type=kwargs["return_tensors"],
|
||||
)
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"pixel_values": images,
|
||||
"original_sizes": original_sizes,
|
||||
"reshaped_input_sizes": reshaped_input_sizes,
|
||||
},
|
||||
tensor_type=kwargs["return_tensors"],
|
||||
)
|
||||
|
||||
def generate_crop_boxes(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
target_size,
|
||||
crop_n_layers: int = 0,
|
||||
overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[list[int]] = 1,
|
||||
device: Optional["torch.device"] = None,
|
||||
):
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`):
|
||||
Input original image
|
||||
target_size (`int`):
|
||||
Target size of the resized image
|
||||
crop_n_layers (`int`, *optional*, defaults to 0):
|
||||
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
|
||||
each layer has 2**i_layer number of image crops.
|
||||
overlap_ratio (`float`, *optional*, defaults to 512/1500):
|
||||
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
points_per_crop (`int`, *optional*, defaults to 32):
|
||||
Number of points to sample from each crop.
|
||||
crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
device (`torch.device`, *optional*, defaults to None):
|
||||
Device to use for the computation. If None, cpu will be used.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||
"""
|
||||
image = self._process_image(image)
|
||||
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
|
||||
image,
|
||||
target_size,
|
||||
crop_n_layers,
|
||||
overlap_ratio,
|
||||
points_per_crop,
|
||||
crop_n_points_downscale_factor,
|
||||
)
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
crop_boxes = crop_boxes.to(device)
|
||||
points_per_crop = points_per_crop.to(device)
|
||||
# cropped_images stays as torch.Tensor
|
||||
input_labels = input_labels.to(device)
|
||||
|
||||
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||
|
||||
def filter_masks(
|
||||
self,
|
||||
masks,
|
||||
iou_scores,
|
||||
original_size,
|
||||
cropped_box_image,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
):
|
||||
"""
|
||||
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
||||
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
||||
bounding boxes and pad the predicted masks if necessary.
|
||||
|
||||
Args:
|
||||
masks (`torch.Tensor`):
|
||||
Input masks.
|
||||
iou_scores (`torch.Tensor`):
|
||||
List of IoU scores.
|
||||
original_size (`tuple[int,int]`):
|
||||
Size of the original image.
|
||||
cropped_box_image (`torch.Tensor`):
|
||||
The cropped image.
|
||||
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||
The threshold for the iou scores.
|
||||
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||
The threshold for the stability score.
|
||||
mask_threshold (`float`, *optional*, defaults to 0):
|
||||
The threshold for the predicted masks.
|
||||
stability_score_offset (`float`, *optional*, defaults to 1):
|
||||
The offset for the stability score used in the `_compute_stability_score` method.
|
||||
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
iou_scores = iou_scores.flatten(0, 1)
|
||||
masks = masks.flatten(0, 1)
|
||||
|
||||
if masks.shape[0] != iou_scores.shape[0]:
|
||||
raise ValueError("masks and iou_scores must have the same batch size.")
|
||||
|
||||
if masks.device != iou_scores.device:
|
||||
iou_scores = iou_scores.to(masks.device)
|
||||
|
||||
batch_size = masks.shape[0]
|
||||
|
||||
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
|
||||
|
||||
if pred_iou_thresh > 0.0:
|
||||
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
||||
|
||||
# compute stability score
|
||||
if stability_score_thresh > 0.0:
|
||||
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
|
||||
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
||||
|
||||
scores = iou_scores[keep_mask]
|
||||
masks = masks[keep_mask]
|
||||
|
||||
# binarize masks
|
||||
masks = masks > mask_threshold
|
||||
converted_boxes = _batched_mask_to_box(masks)
|
||||
|
||||
keep_mask = ~_is_box_near_crop_edge(
|
||||
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
||||
)
|
||||
|
||||
scores = scores[keep_mask]
|
||||
masks = masks[keep_mask]
|
||||
converted_boxes = converted_boxes[keep_mask]
|
||||
|
||||
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
|
||||
# conversion to rle is necessary to run non-maximum suppression
|
||||
masks = _mask_to_rle(masks)
|
||||
|
||||
return masks, scores, converted_boxes
|
||||
|
||||
def post_process_masks(
|
||||
self,
|
||||
masks,
|
||||
original_sizes,
|
||||
reshaped_input_sizes,
|
||||
mask_threshold=0.0,
|
||||
binarize=True,
|
||||
pad_size=None,
|
||||
):
|
||||
"""
|
||||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Args:
|
||||
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
|
||||
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
||||
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
|
||||
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
||||
width) format.
|
||||
reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
|
||||
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||||
The threshold to use for binarizing the masks.
|
||||
binarize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to binarize the masks.
|
||||
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
||||
The target size the images were padded to before being passed to the model. If None, the target size is
|
||||
assumed to be the processor's `pad_size`.
|
||||
Returns:
|
||||
(`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
|
||||
is given by original_size.
|
||||
"""
|
||||
pad_size = self.size if pad_size is None else pad_size
|
||||
target_image_size = (pad_size["height"], pad_size["width"])
|
||||
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
|
||||
original_sizes = original_sizes.tolist()
|
||||
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
|
||||
reshaped_input_sizes = reshaped_input_sizes.tolist()
|
||||
|
||||
output_masks = []
|
||||
for i, original_size in enumerate(original_sizes):
|
||||
if isinstance(masks[i], np.ndarray):
|
||||
masks[i] = torch.from_numpy(masks[i])
|
||||
elif not isinstance(masks[i], torch.Tensor):
|
||||
raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
|
||||
interpolated_mask = F_t.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
|
||||
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
|
||||
interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
|
||||
if binarize:
|
||||
interpolated_mask = interpolated_mask > mask_threshold
|
||||
output_masks.append(interpolated_mask)
|
||||
|
||||
return output_masks
|
||||
|
||||
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
|
||||
"""
|
||||
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
|
||||
|
||||
Args:
|
||||
all_masks (`torch.Tensor`):
|
||||
List of all predicted segmentation masks
|
||||
all_scores (`torch.Tensor`):
|
||||
List of all predicted iou scores
|
||||
all_boxes (`torch.Tensor`):
|
||||
List of all bounding boxes of the predicted masks
|
||||
crops_nms_thresh (`float`):
|
||||
Threshold for NMS (Non Maximum Suppression) algorithm.
|
||||
"""
|
||||
return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
||||
|
||||
|
||||
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = (
|
||||
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
)
|
||||
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
stability_scores = intersections / unions
|
||||
return stability_scores
|
||||
|
||||
|
||||
def _mask_to_rle(input_mask: "torch.Tensor"):
|
||||
"""
|
||||
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
||||
"""
|
||||
# Put in fortran order and flatten height and width
|
||||
batch_size, height, width = input_mask.shape
|
||||
input_mask = input_mask.permute(0, 2, 1).flatten(1)
|
||||
|
||||
# Compute change indices
|
||||
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
||||
change_indices = diff.nonzero()
|
||||
|
||||
# Encode run length
|
||||
out = []
|
||||
for i in range(batch_size):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
||||
if len(cur_idxs) == 0:
|
||||
# No changes => either all 0 or all 1
|
||||
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
|
||||
if input_mask[i, 0] == 0:
|
||||
out.append({"size": [height, width], "counts": [height * width]})
|
||||
else:
|
||||
out.append({"size": [height, width], "counts": [0, height * width]})
|
||||
continue
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if input_mask[i, 0] == 0 else [0]
|
||||
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
|
||||
out.append({"size": [height, width], "counts": counts})
|
||||
return out
|
||||
|
||||
|
||||
def _batched_mask_to_box(masks: "torch.Tensor"):
|
||||
"""
|
||||
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
||||
corresponds the following required indices:
|
||||
- LEFT: left hand side of the bounding box
|
||||
- TOP: top of the bounding box
|
||||
- RIGHT: right of the bounding box
|
||||
- BOTTOM: bottom of the bounding box
|
||||
|
||||
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
||||
is channel_1 x channel_2 x ... x 4.
|
||||
|
||||
Args:
|
||||
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to Cxheightxwidth
|
||||
shape = masks.shape
|
||||
height, width = shape[-2:]
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + height * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + width * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
out = out.reshape(*shape[:-2], 4)
|
||||
return out
|
||||
|
||||
|
||||
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
|
||||
left, top, _, _ = crop_box
|
||||
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
boxes = (boxes + offset).float()
|
||||
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
|
||||
left, top, right, bottom = crop_box
|
||||
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
||||
pad = (left, pad_x - left, top, pad_y - top)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def _generate_crop_boxes(
|
||||
image,
|
||||
target_size: int, # Is it tuple here?
|
||||
crop_n_layers: int = 0,
|
||||
overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[list[int]] = 1,
|
||||
) -> tuple[list[list[int]], list[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
|
||||
Args:
|
||||
image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
|
||||
Image to generate crops for.
|
||||
target_size (`int`):
|
||||
Size of the smallest crop.
|
||||
crop_n_layers (`int`, *optional*):
|
||||
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
|
||||
to run, where each layer has 2**i_layer number of image crops.
|
||||
overlap_ratio (`int`, *optional*):
|
||||
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
|
||||
image length. Later layers with more crops scale down this overlap.
|
||||
points_per_crop (`int`, *optional*):
|
||||
Number of points to sample per crop.
|
||||
crop_n_points_downscale_factor (`int`, *optional*):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
|
||||
if isinstance(image, list):
|
||||
raise ValueError("Only one image is allowed for crop generation.")
|
||||
original_size = image.shape[-2:]
|
||||
|
||||
points_grid = []
|
||||
for i in range(crop_n_layers + 1):
|
||||
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
|
||||
points_grid.append(_build_point_grid(n_points))
|
||||
|
||||
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
|
||||
|
||||
cropped_images, point_grid_per_crop = _generate_crop_images(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||
)
|
||||
crop_boxes = torch.tensor(crop_boxes)
|
||||
crop_boxes = crop_boxes.float()
|
||||
points_per_crop = torch.stack(point_grid_per_crop)
|
||||
points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3)
|
||||
cropped_images = torch.stack(cropped_images)
|
||||
|
||||
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64)
|
||||
|
||||
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||
|
||||
|
||||
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
|
||||
"""
|
||||
Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
|
||||
consists of the following required indices:
|
||||
- X: X coordinate of the top left of the bounding box
|
||||
- Y: Y coordinate of the top left of the bounding box
|
||||
- W: width of the bounding box
|
||||
- H: height of the bounding box
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_height, im_width = original_size
|
||||
short_side = min(im_height, im_width)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_width, im_height])
|
||||
layer_idxs.append(0)
|
||||
for i_layer in range(crop_n_layers):
|
||||
n_crops_per_side = 2 ** (i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
|
||||
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
|
||||
|
||||
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
|
||||
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
|
||||
|
||||
for left, top in product(crop_box_x0, crop_box_y0):
|
||||
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
|
||||
|
||||
|
||||
def _build_point_grid(n_per_side: int) -> torch.Tensor:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = torch.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = torch.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = torch.tile(points_one_side[:, None], (1, n_per_side))
|
||||
points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2)
|
||||
return points
|
||||
|
||||
|
||||
def _generate_crop_images(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
|
||||
):
|
||||
"""
|
||||
Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are
|
||||
also passed.
|
||||
"""
|
||||
cropped_images = []
|
||||
total_points_per_crop = []
|
||||
for i, crop_box in enumerate(crop_boxes):
|
||||
left, top, right, bottom = crop_box
|
||||
cropped_im = image[:, top:bottom, left:right]
|
||||
|
||||
cropped_images.append(cropped_im)
|
||||
|
||||
cropped_im_size = cropped_im.shape[-2:]
|
||||
points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0)
|
||||
|
||||
points = points_grid[layer_idxs[i]] * points_scale
|
||||
normalized_points = _normalize_coordinates(target_size, points, original_size)
|
||||
total_points_per_crop.append(normalized_points)
|
||||
|
||||
return cropped_images, total_points_per_crop
|
||||
|
||||
|
||||
def _normalize_coordinates(
|
||||
target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
|
||||
format.
|
||||
"""
|
||||
old_height, old_width = original_size
|
||||
|
||||
scale = target_size * 1.0 / max(old_height, old_width)
|
||||
new_height, new_width = old_height * scale, old_width * scale
|
||||
new_width = int(new_width + 0.5)
|
||||
new_height = int(new_height + 0.5)
|
||||
|
||||
coords = deepcopy(coords).float()
|
||||
|
||||
if is_bounding_box:
|
||||
coords = coords.reshape(-1, 2, 2)
|
||||
|
||||
coords[..., 0] = coords[..., 0] * (new_width / old_width)
|
||||
coords[..., 1] = coords[..., 1] * (new_height / old_height)
|
||||
|
||||
if is_bounding_box:
|
||||
coords = coords.reshape(-1, 4)
|
||||
|
||||
return coords
|
||||
|
||||
|
||||
def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor:
|
||||
"""Compute a binary mask from an uncompressed RLE."""
|
||||
height, width = rle["size"]
|
||||
mask = torch.empty(height * width, dtype=bool)
|
||||
idx = 0
|
||||
parity = False
|
||||
for count in rle["counts"]:
|
||||
mask[idx : idx + count] = parity
|
||||
idx += count
|
||||
parity = not parity
|
||||
mask = mask.reshape(width, height)
|
||||
return mask.transpose(0, 1) # Reshape to original shape
|
||||
|
||||
|
||||
def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
||||
"""
|
||||
Perform NMS (Non Maximum Suppression) on the outputs.
|
||||
|
||||
Args:
|
||||
rle_masks (`torch.Tensor`):
|
||||
binary masks in the RLE format
|
||||
iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
|
||||
iou_scores predicted by the model
|
||||
mask_boxes (`torch.Tensor`):
|
||||
The bounding boxes corresponding to segmentation masks
|
||||
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
||||
NMS threshold.
|
||||
"""
|
||||
keep_by_nms = batched_nms(
|
||||
boxes=mask_boxes.float(),
|
||||
scores=iou_scores,
|
||||
idxs=torch.zeros(mask_boxes.shape[0]),
|
||||
iou_threshold=amg_crops_nms_thresh,
|
||||
)
|
||||
|
||||
iou_scores = iou_scores[keep_by_nms]
|
||||
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
||||
mask_boxes = mask_boxes[keep_by_nms]
|
||||
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
||||
|
||||
return masks, iou_scores, rle_masks, mask_boxes
|
||||
|
||||
|
||||
__all__ = ["SamImageProcessorFast"]
|
||||
301
tests/models/sam/test_image_processing_sam.py
Normal file
301
tests/models/sam/test_image_processing_sam.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torchvision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import SamImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import SamImageProcessorFast
|
||||
|
||||
|
||||
class SamImageProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_pad=True,
|
||||
pad_size=None,
|
||||
mask_size=None,
|
||||
mask_pad_size=None,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
):
|
||||
size = size if size is not None else {"longest_edge": 20}
|
||||
pad_size = pad_size if pad_size is not None else {"height": 20, "width": 20}
|
||||
mask_size = mask_size if mask_size is not None else {"longest_edge": 12}
|
||||
mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 12, "width": 12}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_pad = do_pad
|
||||
self.pad_size = pad_size
|
||||
self.mask_size = mask_size
|
||||
self.mask_pad_size = mask_pad_size
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_normalize": self.do_normalize,
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_pad": self.do_pad,
|
||||
"pad_size": self.pad_size,
|
||||
"mask_size": self.mask_size,
|
||||
"mask_pad_size": self.mask_pad_size,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.pad_size["height"], self.pad_size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_single_inputs
|
||||
def prepare_semantic_single_inputs():
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
example = ds[0]
|
||||
return example["image"], example["map"]
|
||||
|
||||
|
||||
# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_batch_inputs
|
||||
def prepare_semantic_batch_inputs():
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
return list(ds["image"][:2]), list(ds["map"][:2])
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class SamImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = SamImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = SamImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = SamImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_pad"))
|
||||
self.assertTrue(hasattr(image_processing, "pad_size"))
|
||||
self.assertTrue(hasattr(image_processing, "mask_size"))
|
||||
self.assertTrue(hasattr(image_processing, "mask_pad_size"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing_class = image_processing_class(**self.image_processor_dict)
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"longest_edge": 20})
|
||||
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size={"longest_edge": 42})
|
||||
self.assertEqual(image_processor.size, {"longest_edge": 42})
|
||||
|
||||
def test_call_segmentation_maps(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processor
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
maps = []
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||
|
||||
# Test not batched input
|
||||
encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.pad_size["height"],
|
||||
self.image_processor_tester.pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.mask_pad_size["height"],
|
||||
self.image_processor_tester.mask_pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched
|
||||
encoding = image_processor(image_inputs, maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.pad_size["height"],
|
||||
self.image_processor_tester.pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.mask_pad_size["height"],
|
||||
self.image_processor_tester.mask_pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test not batched input (PIL images)
|
||||
image, segmentation_map = prepare_semantic_single_inputs()
|
||||
|
||||
encoding = image_processor(image, segmentation_map, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.pad_size["height"],
|
||||
self.image_processor_tester.pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.mask_pad_size["height"],
|
||||
self.image_processor_tester.mask_pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched input (PIL images)
|
||||
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
encoding = image_processor(images, segmentation_maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.pad_size["height"],
|
||||
self.image_processor_tester.pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.mask_pad_size["height"],
|
||||
self.image_processor_tester.mask_pad_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
def test_slow_fast_equivalence(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
dummy_image, dummy_map = prepare_semantic_single_inputs()
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
|
||||
|
||||
def test_slow_fast_equivalence_batched(self):
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
|
||||
|
||||
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
|
||||
self.assertLessEqual(
|
||||
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
|
||||
)
|
||||
Reference in New Issue
Block a user