Add fast image processor SAM (#39385)

* add fast image processor sam

* nits
This commit is contained in:
Yoni Gozlan
2025-07-18 13:27:16 -04:00
committed by GitHub
parent 561a79a2f4
commit de0dd3139d
7 changed files with 1191 additions and 5 deletions

View File

@@ -149,6 +149,11 @@ alt="drawing" width="900"/>
[[autodoc]] SamImageProcessor
## SamImageProcessorFast
[[autodoc]] SamImageProcessorFast
## SamVisionModel
[[autodoc]] SamVisionModel

View File

@@ -147,8 +147,8 @@ else:
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor",)),
("sam_hq", ("SamImageProcessor",)),
("sam", ("SamImageProcessor", "SamImageProcessorFast")),
("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),

View File

@@ -656,6 +656,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vits", ("VitsTokenizer", None)),
(
"voxtral",
(
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
),
),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),

View File

@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_sam import *
from .image_processing_sam import *
from .image_processing_sam_fast import *
from .modeling_sam import *
from .modeling_tf_sam import *
from .processing_sam import *

View File

@@ -387,6 +387,11 @@ class SamImageProcessor(BaseImageProcessor):
return segmentation_map, original_size
def __call__(self, images, segmentation_maps=None, **kwargs):
# Overrides the `__call__` method of the `BaseImageProcessor` class such that the images and segmentation maps can both
# be passed in as positional arguments.
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
@filter_out_non_signature_kwargs()
def preprocess(
self,

View File

@@ -0,0 +1,865 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for SAM."""
import math
from copy import deepcopy
from itertools import product
from typing import Any, Optional, Union
import numpy as np
import torch
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
make_list_of_images,
pil_torch_interpolation_mapping,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
from torch.nn import functional as F_t
if is_torchvision_available() and is_torchvision_v2_available():
from torchvision.ops.boxes import batched_nms
from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
from torchvision.ops.boxes import batched_nms
from torchvision.transforms import functional as F
class SamFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
r"""
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
pad_size (`dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing.
mask_size (`dict[str, int]`, *optional*):
The size `{"longest_edge": int}` to resize the segmentation maps to.
mask_pad_size (`dict[str, int]`, *optional*):
The size `{"height": int, "width": int}` to pad the segmentation maps to. Must be larger than any segmentation
map size provided for preprocessing.
"""
mask_size: Optional[dict[str, int]]
do_pad: Optional[bool]
pad_size: Optional[dict[str, int]]
mask_pad_size: Optional[dict[str, int]]
@auto_docstring
class SamImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
size = {"longest_edge": 1024}
mask_size = {"longest_edge": 256}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
valid_kwargs = SamFastImageProcessorKwargs
do_pad = True
pad_size = {"height": 1024, "width": 1024}
mask_pad_size = {"height": 256, "width": 256}
def __init__(self, **kwargs: Unpack[SamFastImageProcessorKwargs]):
super().__init__(**kwargs)
def pad_image(self, images: "torch.Tensor", pad_size: SizeDict):
"""Pad images to the specified size."""
output_height, output_width = pad_size.height, pad_size.width
input_height, input_width = images.shape[-2:]
pad_width = output_width - input_width
pad_height = output_height - input_height
padding = (0, 0, pad_width, pad_height)
return F.pad(images, padding)
def _get_preprocess_shape(self, old_shape: tuple[int, int], longest_edge: int):
"""
Compute the output size given input size and target long side length.
"""
oldh, oldw = old_shape
scale = longest_edge * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
newh = int(newh + 0.5)
neww = int(neww + 0.5)
return (newh, neww)
def resize(
self, image: "torch.Tensor", size: SizeDict, interpolation: Optional["F.InterpolationMode"], **kwargs
) -> "torch.Tensor":
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`dict[str, int]`):
Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
edge of the image will be resized to the specified size, while the other edge will be resized to
maintain the aspect ratio.
interpolation:
`F.InterpolationMode` filter to use when resizing the image e.g. `F.InterpolationMode.BICUBIC`.
Returns:
`torch.Tensor`: The resized image.
"""
if not size.longest_edge:
raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
input_size = image.shape[-2:]
output_height, output_width = self._get_preprocess_shape(input_size, size.longest_edge)
return super().resize(
image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs
)
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: bool,
pad_size: SizeDict,
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
if do_pad:
stacked_images = self.pad_image(stacked_images, pad_size)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return processed_images
def _preprocess_segmentation_maps(
self,
segmentation_maps,
**kwargs,
):
"""Preprocesses segmentation maps."""
processed_segmentation_maps = []
for segmentation_map in segmentation_maps:
segmentation_map = self._process_image(
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
)
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
processed_segmentation_maps.append(segmentation_map)
kwargs["do_rescale"] = False
kwargs["do_normalize"] = False
kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
kwargs["size"] = kwargs.pop("mask_size")
kwargs["pad_size"] = kwargs.pop("mask_pad_size")
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
processed_segmentation_maps = processed_segmentation_maps.squeeze(1) # Remove channel dimension
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
return processed_segmentation_maps
def _further_process_kwargs(
self,
size: Optional[SizeDict] = None,
pad_size: Optional[SizeDict] = None,
mask_size: Optional[SizeDict] = None,
mask_pad_size: Optional[SizeDict] = None,
default_to_square: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
data_format: Optional[ChannelDimension] = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if kwargs is None:
kwargs = {}
if size is not None:
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
if pad_size is not None:
pad_size = SizeDict(**get_size_dict(pad_size, param_name="pad_size"))
if mask_size is not None:
mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
if mask_pad_size is not None:
mask_pad_size = SizeDict(**get_size_dict(mask_pad_size, param_name="mask_pad_size"))
if isinstance(image_mean, list):
image_mean = tuple(image_mean)
if isinstance(image_std, list):
image_std = tuple(image_std)
if data_format is None:
data_format = ChannelDimension.FIRST
kwargs["size"] = size
kwargs["pad_size"] = pad_size
kwargs["mask_size"] = mask_size
kwargs["mask_pad_size"] = mask_pad_size
kwargs["default_to_square"] = default_to_square
kwargs["image_mean"] = image_mean
kwargs["image_std"] = image_std
kwargs["data_format"] = data_format
return kwargs
@auto_docstring
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[SamFastImageProcessorKwargs],
) -> BatchFeature:
r"""
segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess.
"""
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Prepare segmentation maps
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")
original_sizes = [image.shape[-2:] for image in images]
images = self._preprocess(
images=images,
**kwargs,
)
reshaped_input_sizes = [image.shape[-2:] for image in images]
if segmentation_maps is not None:
segmentation_maps = self._preprocess_segmentation_maps(
segmentation_maps=segmentation_maps,
**kwargs,
)
return BatchFeature(
data={
"pixel_values": images,
"labels": segmentation_maps,
"original_sizes": original_sizes,
"reshaped_input_sizes": reshaped_input_sizes,
},
tensor_type=kwargs["return_tensors"],
)
return BatchFeature(
data={
"pixel_values": images,
"original_sizes": original_sizes,
"reshaped_input_sizes": reshaped_input_sizes,
},
tensor_type=kwargs["return_tensors"],
)
def generate_crop_boxes(
self,
image: "torch.Tensor",
target_size,
crop_n_layers: int = 0,
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[list[int]] = 1,
device: Optional["torch.device"] = None,
):
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
Args:
image (`torch.Tensor`):
Input original image
target_size (`int`):
Target size of the resized image
crop_n_layers (`int`, *optional*, defaults to 0):
If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
each layer has 2**i_layer number of image crops.
overlap_ratio (`float`, *optional*, defaults to 512/1500):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
points_per_crop (`int`, *optional*, defaults to 32):
Number of points to sample from each crop.
crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*, defaults to None):
Device to use for the computation. If None, cpu will be used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
image = self._process_image(image)
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
image,
target_size,
crop_n_layers,
overlap_ratio,
points_per_crop,
crop_n_points_downscale_factor,
)
if device is None:
device = torch.device("cpu")
crop_boxes = crop_boxes.to(device)
points_per_crop = points_per_crop.to(device)
# cropped_images stays as torch.Tensor
input_labels = input_labels.to(device)
return crop_boxes, points_per_crop, cropped_images, input_labels
def filter_masks(
self,
masks,
iou_scores,
original_size,
cropped_box_image,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
Args:
masks (`torch.Tensor`):
Input masks.
iou_scores (`torch.Tensor`):
List of IoU scores.
original_size (`tuple[int,int]`):
Size of the original image.
cropped_box_image (`torch.Tensor`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
"""
original_height, original_width = original_size
iou_scores = iou_scores.flatten(0, 1)
masks = masks.flatten(0, 1)
if masks.shape[0] != iou_scores.shape[0]:
raise ValueError("masks and iou_scores must have the same batch size.")
if masks.device != iou_scores.device:
iou_scores = iou_scores.to(masks.device)
batch_size = masks.shape[0]
keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
if pred_iou_thresh > 0.0:
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
# compute stability score
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
masks = masks[keep_mask]
# binarize masks
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box(masks)
keep_mask = ~_is_box_near_crop_edge(
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
)
scores = scores[keep_mask]
masks = masks[keep_mask]
converted_boxes = converted_boxes[keep_mask]
masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
# conversion to rle is necessary to run non-maximum suppression
masks = _mask_to_rle(masks)
return masks, scores, converted_boxes
def post_process_masks(
self,
masks,
original_sizes,
reshaped_input_sizes,
mask_threshold=0.0,
binarize=True,
pad_size=None,
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
Returns:
(`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
is given by original_size.
"""
pad_size = self.size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
original_sizes = original_sizes.tolist()
if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
reshaped_input_sizes = reshaped_input_sizes.tolist()
output_masks = []
for i, original_size in enumerate(original_sizes):
if isinstance(masks[i], np.ndarray):
masks[i] = torch.from_numpy(masks[i])
elif not isinstance(masks[i], torch.Tensor):
raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
interpolated_mask = F_t.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
interpolated_mask = F_t.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
output_masks.append(interpolated_mask)
return output_masks
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
"""
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
Args:
all_masks (`torch.Tensor`):
List of all predicted segmentation masks
all_scores (`torch.Tensor`):
List of all predicted iou scores
all_boxes (`torch.Tensor`):
List of all bounding boxes of the predicted masks
crops_nms_thresh (`float`):
Threshold for NMS (Non Maximum Suppression) algorithm.
"""
return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh)
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
)
unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
stability_scores = intersections / unions
return stability_scores
def _mask_to_rle(input_mask: "torch.Tensor"):
"""
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
"""
# Put in fortran order and flatten height and width
batch_size, height, width = input_mask.shape
input_mask = input_mask.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
if len(cur_idxs) == 0:
# No changes => either all 0 or all 1
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
if input_mask[i, 0] == 0:
out.append({"size": [height, width], "counts": [height * width]})
else:
out.append({"size": [height, width], "counts": [0, height * width]})
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
out.append({"size": [height, width], "counts": counts})
return out
def _batched_mask_to_box(masks: "torch.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
corresponds the following required indices:
- LEFT: left hand side of the bounding box
- TOP: top of the bounding box
- RIGHT: right of the bounding box
- BOTTOM: bottom of the bounding box
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
is channel_1 x channel_2 x ... x 4.
Args:
- masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to Cxheightxwidth
shape = masks.shape
height, width = shape[-2:]
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + height * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + width * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
out = out.reshape(*shape[:-2], 4)
return out
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
left, top, _, _ = crop_box
offset = torch.tensor([[left, top, left, top]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
boxes = (boxes + offset).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int):
left, top, right, bottom = crop_box
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
pad = (left, pad_x - left, top, pad_y - top)
return torch.nn.functional.pad(masks, pad, value=0)
def _generate_crop_boxes(
image,
target_size: int, # Is it tuple here?
crop_n_layers: int = 0,
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[list[int]] = 1,
) -> tuple[list[list[int]], list[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
Args:
image (Union[`numpy.ndarray`, `PIL.Image`, `torch.Tensor`]):
Image to generate crops for.
target_size (`int`):
Size of the smallest crop.
crop_n_layers (`int`, *optional*):
If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers
to run, where each layer has 2**i_layer number of image crops.
overlap_ratio (`int`, *optional*):
Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the
image length. Later layers with more crops scale down this overlap.
points_per_crop (`int`, *optional*):
Number of points to sample per crop.
crop_n_points_downscale_factor (`int`, *optional*):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if isinstance(image, list):
raise ValueError("Only one image is allowed for crop generation.")
original_size = image.shape[-2:]
points_grid = []
for i in range(crop_n_layers + 1):
n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
points_grid.append(_build_point_grid(n_points))
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
cropped_images, point_grid_per_crop = _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
)
crop_boxes = torch.tensor(crop_boxes)
crop_boxes = crop_boxes.float()
points_per_crop = torch.stack(point_grid_per_crop)
points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3)
cropped_images = torch.stack(cropped_images)
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.int64)
return crop_boxes, points_per_crop, cropped_images, input_labels
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
"""
Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format
consists of the following required indices:
- X: X coordinate of the top left of the bounding box
- Y: Y coordinate of the top left of the bounding box
- W: width of the bounding box
- H: height of the bounding box
"""
crop_boxes, layer_idxs = [], []
im_height, im_width = original_size
short_side = min(im_height, im_width)
# Original image
crop_boxes.append([0, 0, im_width, im_height])
layer_idxs.append(0)
for i_layer in range(crop_n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))
crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]
for left, top in product(crop_box_x0, crop_box_y0):
box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
def _build_point_grid(n_per_side: int) -> torch.Tensor:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = torch.linspace(offset, 1 - offset, n_per_side)
points_x = torch.tile(points_one_side[None, :], (n_per_side, 1))
points_y = torch.tile(points_one_side[:, None], (1, n_per_side))
points = torch.stack([points_x, points_y], dim=-1).reshape(-1, 2)
return points
def _generate_crop_images(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
):
"""
Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are
also passed.
"""
cropped_images = []
total_points_per_crop = []
for i, crop_box in enumerate(crop_boxes):
left, top, right, bottom = crop_box
cropped_im = image[:, top:bottom, left:right]
cropped_images.append(cropped_im)
cropped_im_size = cropped_im.shape[-2:]
points_scale = torch.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0)
points = points_grid[layer_idxs[i]] * points_scale
normalized_points = _normalize_coordinates(target_size, points, original_size)
total_points_per_crop.append(normalized_points)
return cropped_images, total_points_per_crop
def _normalize_coordinates(
target_size: int, coords: torch.Tensor, original_size: tuple[int, int], is_bounding_box=False
) -> torch.Tensor:
"""
Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width)
format.
"""
old_height, old_width = original_size
scale = target_size * 1.0 / max(old_height, old_width)
new_height, new_width = old_height * scale, old_width * scale
new_width = int(new_width + 0.5)
new_height = int(new_height + 0.5)
coords = deepcopy(coords).float()
if is_bounding_box:
coords = coords.reshape(-1, 2, 2)
coords[..., 0] = coords[..., 0] * (new_width / old_width)
coords[..., 1] = coords[..., 1] * (new_height / old_height)
if is_bounding_box:
coords = coords.reshape(-1, 4)
return coords
def _rle_to_mask(rle: dict[str, Any]) -> torch.Tensor:
"""Compute a binary mask from an uncompressed RLE."""
height, width = rle["size"]
mask = torch.empty(height * width, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity = not parity
mask = mask.reshape(width, height)
return mask.transpose(0, 1) # Reshape to original shape
def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maximum Suppression) on the outputs.
Args:
rle_masks (`torch.Tensor`):
binary masks in the RLE format
iou_scores (`torch.Tensor` of shape (nb_masks, 1)):
iou_scores predicted by the model
mask_boxes (`torch.Tensor`):
The bounding boxes corresponding to segmentation masks
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
NMS threshold.
"""
keep_by_nms = batched_nms(
boxes=mask_boxes.float(),
scores=iou_scores,
idxs=torch.zeros(mask_boxes.shape[0]),
iou_threshold=amg_crops_nms_thresh,
)
iou_scores = iou_scores[keep_by_nms]
rle_masks = [rle_masks[i] for i in keep_by_nms]
mask_boxes = mask_boxes[keep_by_nms]
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes
__all__ = ["SamImageProcessorFast"]

View File

@@ -0,0 +1,301 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from transformers import SamImageProcessor
if is_torchvision_available():
from transformers import SamImageProcessorFast
class SamImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_pad=True,
pad_size=None,
mask_size=None,
mask_pad_size=None,
do_resize=True,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"longest_edge": 20}
pad_size = pad_size if pad_size is not None else {"height": 20, "width": 20}
mask_size = mask_size if mask_size is not None else {"longest_edge": 12}
mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 12, "width": 12}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_pad = do_pad
self.pad_size = pad_size
self.mask_size = mask_size
self.mask_pad_size = mask_pad_size
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def prepare_image_processor_dict(self):
return {
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_normalize": self.do_normalize,
"do_resize": self.do_resize,
"size": self.size,
"do_pad": self.do_pad,
"pad_size": self.pad_size,
"mask_size": self.mask_size,
"mask_pad_size": self.mask_pad_size,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.pad_size["height"], self.pad_size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_single_inputs
def prepare_semantic_single_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
example = ds[0]
return example["image"], example["map"]
# Copied from transformers.tests.models.beit.test_image_processing_beit.prepare_semantic_batch_inputs
def prepare_semantic_batch_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
return list(ds["image"][:2]), list(ds["map"][:2])
@require_torch
@require_vision
class SamImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SamImageProcessor if is_vision_available() else None
fast_image_processing_class = SamImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = SamImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "pad_size"))
self.assertTrue(hasattr(image_processing, "mask_size"))
self.assertTrue(hasattr(image_processing, "mask_pad_size"))
def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processing_class = image_processing_class(**self.image_processor_dict)
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"longest_edge": 20})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size={"longest_edge": 42})
self.assertEqual(image_processor.size, {"longest_edge": 42})
def test_call_segmentation_maps(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processor
image_processor = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
maps = []
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
maps.append(torch.zeros(image.shape[-2:]).long())
# Test not batched input
encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.pad_size["height"],
self.image_processor_tester.pad_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.mask_pad_size["height"],
self.image_processor_tester.mask_pad_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched
encoding = image_processor(image_inputs, maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
self.image_processor_tester.pad_size["height"],
self.image_processor_tester.pad_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.mask_pad_size["height"],
self.image_processor_tester.mask_pad_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test not batched input (PIL images)
image, segmentation_map = prepare_semantic_single_inputs()
encoding = image_processor(image, segmentation_map, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.pad_size["height"],
self.image_processor_tester.pad_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.mask_pad_size["height"],
self.image_processor_tester.mask_pad_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched input (PIL images)
images, segmentation_maps = prepare_semantic_batch_inputs()
encoding = image_processor(images, segmentation_maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
2,
self.image_processor_tester.num_channels,
self.image_processor_tester.pad_size["height"],
self.image_processor_tester.pad_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
self.image_processor_tester.mask_pad_size["height"],
self.image_processor_tester.mask_pad_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image, dummy_map = prepare_semantic_single_inputs()
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
)
self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_images, dummy_maps = prepare_semantic_batch_inputs()
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)