TF port of the Segment Anything Model (SAM) (#22970)

* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matt
2023-05-19 14:14:13 +01:00
committed by GitHub
parent 8aa8513f71
commit 1c460a5273
14 changed files with 2940 additions and 44 deletions

View File

@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |

View File

@@ -99,3 +99,9 @@ Resources:
[[autodoc]] SamModel
- forward
## TFSamModel
[[autodoc]] TFSamModel
- call

View File

@@ -3406,6 +3406,13 @@ else:
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.sam"].extend(
[
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -6657,6 +6664,11 @@ if TYPE_CHECKING:
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.sam import (
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSamModel,
TFSamPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,

View File

@@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
@@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "TFSamModel"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
@@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING

View File

@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
@@ -39,6 +45,17 @@ else:
"SamModel",
"SamPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_sam"] = [
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
@@ -66,6 +83,14 @@ if TYPE_CHECKING:
else:
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()

View File

@@ -34,7 +34,14 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends
from ...utils import (
TensorType,
is_tf_available,
is_torch_available,
is_torchvision_available,
logging,
requires_backends,
)
if is_torch_available():
@@ -44,6 +51,12 @@ if is_torch_available():
if is_torchvision_available():
from torchvision.ops.boxes import batched_nms
if is_tf_available():
import tensorflow as tf
from tensorflow.experimental import numpy as tnp
from ...tf_utils import flatten, shape_list
logger = logging.get_logger(__name__)
@@ -372,6 +385,61 @@ class SamImageProcessor(BaseImageProcessor):
return encoded_outputs
def post_process_masks(
self,
masks,
original_sizes,
reshaped_input_sizes,
mask_threshold=0.0,
binarize=True,
pad_size=None,
return_tensors="pt",
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
The original sizes of each image before it was resized to the model's expected input shape, in (height,
width) format.
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
return_tensors (`str`, *optional*, defaults to `"pt"`):
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
Returns:
(`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where
(height, width) is given by original_size.
"""
if return_tensors == "pt":
return self._post_process_masks_pt(
masks=masks,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
mask_threshold=mask_threshold,
binarize=binarize,
pad_size=pad_size,
)
elif return_tensors == "tf":
return self._post_process_masks_tf(
masks=masks,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
mask_threshold=mask_threshold,
binarize=binarize,
pad_size=pad_size,
)
else:
raise ValueError("return_tensors must be either 'pt' or 'tf'")
def _post_process_masks_pt(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
@@ -418,21 +486,70 @@ class SamImageProcessor(BaseImageProcessor):
return output_masks
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
def _post_process_masks_tf(
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
):
"""
Remove padding and upscale masks to the original image size.
Args:
masks (`tf.Tensor`):
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
original_sizes (`tf.Tensor`):
The original size of the images before resizing for input to the model, in (height, width) format.
reshaped_input_sizes (`tf.Tensor`):
The size of the image input to the model, in (height, width) format. Used to remove padding.
mask_threshold (`float`, *optional*, defaults to 0.0):
The threshold to use for binarizing the masks.
binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the masks.
pad_size (`int`, *optional*, defaults to `self.pad_size`):
The target size the images were padded to before being passed to the model. If None, the target size is
assumed to be the processor's `pad_size`.
Returns:
(`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is
given by original_size.
"""
requires_backends(self, ["tf"])
pad_size = self.pad_size if pad_size is None else pad_size
target_image_size = (pad_size["height"], pad_size["width"])
output_masks = []
for i, original_size in enumerate(original_sizes):
# tf.image expects NHWC, we transpose the NCHW inputs for it
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
if binarize:
interpolated_mask = interpolated_mask > mask_threshold
# And then we transpose them back at the end
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
return output_masks
def post_process_for_mask_generation(
self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
):
"""
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
Args:
all_masks (`List[torch.Tensor]`):
all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all predicted segmentation masks
all_scores (`List[torch.Tensor]`):
all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all predicted iou scores
all_boxes (`List[torch.Tensor]`):
all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
List of all bounding boxes of the predicted masks
crops_nms_thresh (`float`):
Threshold for NMS (Non Maximum Suppression) algorithm.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
if return_tensors == "pt":
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
elif return_tensors == "tf":
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
def generate_crop_boxes(
self,
@@ -443,6 +560,7 @@ class SamImageProcessor(BaseImageProcessor):
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
return_tensors: str = "pt",
):
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
@@ -464,10 +582,35 @@ class SamImageProcessor(BaseImageProcessor):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*, defaults to None):
Device to use for the computation. If None, cpu will be used.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
return _generate_crop_boxes(
image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
image,
target_size,
crop_n_layers,
overlap_ratio,
points_per_crop,
crop_n_points_downscale_factor,
)
if return_tensors == "pt":
if device is None:
device = torch.device("cpu")
crop_boxes = torch.tensor(crop_boxes, device=device)
points_per_crop = torch.tensor(points_per_crop, device=device)
# cropped_images stays as np
input_labels = torch.tensor(input_labels, device=device)
elif return_tensors == "tf":
if device is not None:
raise ValueError("device is not a supported argument when return_tensors is tf!")
crop_boxes = tf.convert_to_tensor(crop_boxes)
points_per_crop = tf.convert_to_tensor(points_per_crop)
# cropped_images stays as np
input_labels = tf.convert_to_tensor(input_labels)
else:
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
return crop_boxes, points_per_crop, cropped_images, input_labels
def filter_masks(
self,
@@ -479,6 +622,67 @@ class SamImageProcessor(BaseImageProcessor):
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
return_tensors="pt",
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
Args:
masks (`Union[torch.Tensor, tf.Tensor]`):
Input masks.
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
List of IoU scores.
original_size (`Tuple[int,int]`):
Size of the orginal image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
return_tensors (`str`, *optional*, defaults to `pt`):
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
"""
if return_tensors == "pt":
return self._filter_masks_pt(
masks=masks,
iou_scores=iou_scores,
original_size=original_size,
cropped_box_image=cropped_box_image,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
mask_threshold=mask_threshold,
stability_score_offset=stability_score_offset,
)
elif return_tensors == "tf":
return self._filter_masks_tf(
masks=masks,
iou_scores=iou_scores,
original_size=original_size,
cropped_box_image=cropped_box_image,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
mask_threshold=mask_threshold,
stability_score_offset=stability_score_offset,
)
def _filter_masks_pt(
self,
masks,
iou_scores,
original_size,
cropped_box_image,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
@@ -525,7 +729,7 @@ class SamImageProcessor(BaseImageProcessor):
# compute stability score
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
@@ -549,8 +753,85 @@ class SamImageProcessor(BaseImageProcessor):
return masks, scores, converted_boxes
def _filter_masks_tf(
self,
masks,
iou_scores,
original_size,
cropped_box_image,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
mask_threshold=0,
stability_score_offset=1,
):
"""
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
bounding boxes and pad the predicted masks if necessary.
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
Args:
masks (`tf.Tensor`):
Input masks.
iou_scores (`tf.Tensor`):
List of IoU scores.
original_size (`Tuple[int,int]`):
Size of the orginal image.
cropped_box_image (`np.array`):
The cropped image.
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
The threshold for the iou scores.
stability_score_thresh (`float`, *optional*, defaults to 0.95):
The threshold for the stability score.
mask_threshold (`float`, *optional*, defaults to 0):
The threshold for the predicted masks.
stability_score_offset (`float`, *optional*, defaults to 1):
The offset for the stability score used in the `_compute_stability_score` method.
"""
requires_backends(self, ["tf"])
original_height, original_width = original_size
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
if masks.shape[0] != iou_scores.shape[0]:
raise ValueError("masks and iou_scores must have the same batch size.")
batch_size = masks.shape[0]
keep_mask = tf.ones(batch_size, dtype=tf.bool)
if pred_iou_thresh > 0.0:
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
# compute stability score
if stability_score_thresh > 0.0:
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
scores = iou_scores[keep_mask]
masks = masks[keep_mask]
# binarize masks
masks = masks > mask_threshold
converted_boxes = _batched_mask_to_box_tf(masks)
keep_mask = ~_is_box_near_crop_edge_tf(
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
)
scores = scores[keep_mask]
masks = masks[keep_mask]
converted_boxes = converted_boxes[keep_mask]
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
# conversion to rle is necessary to run non-maximum suppresion
masks = _mask_to_rle_tf(masks)
return masks, scores, converted_boxes
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
# One mask is always contained inside the other.
# Save memory by preventing unnecesary cast to torch.int64
intersections = (
@@ -561,6 +842,17 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi
return stability_scores
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
# Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
# we get the right division results.
intersections = tf.count_nonzero(
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
)
unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
stability_scores = intersections / unions
return stability_scores
def _build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
@@ -606,7 +898,6 @@ def _generate_crop_boxes(
overlap_ratio: float = 512 / 1500,
points_per_crop: Optional[int] = 32,
crop_n_points_downscale_factor: Optional[List[int]] = 1,
device: Optional["torch.device"] = None,
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
@@ -626,11 +917,7 @@ def _generate_crop_boxes(
Number of points to sample per crop.
crop_n_points_downscale_factor (`int`, *optional*):
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
device (`torch.device`, *optional*):
Device to run the crop generation on. Defaults to CPU.
"""
if device is None:
device = torch.device("cpu")
if isinstance(image, list):
raise ValueError("Only one image is allowed for crop generation.")
@@ -648,12 +935,11 @@ def _generate_crop_boxes(
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
)
crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
point_grid_per_crop = np.array([point_grid_per_crop])
points_per_crop = torch.tensor(point_grid_per_crop, device=device)
points_per_crop = points_per_crop.permute(0, 2, 1, 3)
crop_boxes = crop_boxes.astype(np.float32)
points_per_crop = np.array([point_grid_per_crop])
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
return crop_boxes, points_per_crop, cropped_images, input_labels
@@ -730,6 +1016,16 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
return torch.nn.functional.pad(masks, pad, value=0)
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
left, top, right, bottom = crop_box
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
pad = (left, pad_x - left, top, pad_y - top)
return tf.pad(masks, pad, constant_values=0)
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
@@ -748,6 +1044,24 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
return torch.any(near_crop_edge, dim=1)
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
left, top, _, _ = crop_box
offset = tf.convert_to_tensor([[left, top, left, top]])
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = tf.expand_dims(offset, 1)
boxes = tf.cast(boxes + offset, tf.float32)
near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
return tf.reduce_any(near_crop_edge, axis=1)
def _batched_mask_to_box(masks: "torch.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
@@ -797,6 +1111,54 @@ def _batched_mask_to_box(masks: "torch.Tensor"):
return out
def _batched_mask_to_box_tf(masks: "tf.Tensor"):
"""
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
corresponds the following required indices:
- LEFT: left hand side of the bounding box
- TOP: top of the bounding box
- RIGHT: right of the bounding box
- BOTTOM: bottom of the bounding box
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
is channel_1 x channel_2 x ... x 4.
Args:
- masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
"""
if tf.size(masks) == 0:
return tf.zeros([*masks.shape[:-2], 4])
# Normalize shape to Cxheightxwidth
shape = shape_list(masks)
height, width = shape[-2:]
# Get top and bottom edges
in_height = tf.reduce_max(masks, axis=-1)
in_height_coords = in_height * tf.range(height)[None, :]
bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
in_height_coords = in_height_coords + height * (~in_height)
top_edges = tf.reduce_min(in_height_coords, axis=-1)
# Get left and right edges
in_width, _ = tf.reduce_max(masks, axis=-2)
in_width_coords = in_width * tf.range(width)[None, :]
right_edges, _ = tf.reduce_max(in_width_coords, axis=-1)
in_width_coords = in_width_coords + width * (~in_width)
left_edges, _ = tf.reduce_min(in_width_coords, axis=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
out = out * tf.expand_dims(~empty_filter, -1)
# Return to original shape
out = tf.reshape(out, *shape[:-2], 4)
return out
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
"""
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
@@ -820,6 +1182,29 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
return out
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
"""
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
"""
# Put in fortran order and flatten height and width
batch_size, height, width = input_mask.shape
input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
# Compute change indices
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
change_indices = tf.where(diff)
# Encode run length
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
out.append({"size": [height, width], "counts": counts})
return out
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
height, width = rle["size"]
@@ -836,7 +1221,7 @@ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maxium Suppression) on the outputs.
Perform NMS (Non Maximum Suppression) on the outputs.
Args:
rle_masks (`torch.Tensor`):
@@ -861,3 +1246,32 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
"""
Perform NMS (Non Maximum Suppression) on the outputs.
Args:
rle_masks (`tf.Tensor`):
binary masks in the RLE format
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
iou_scores predicted by the model
mask_boxes (`tf.Tensor`):
The bounding boxes corresponding to segmentation masks
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
NMS threshold.
"""
keep_by_nms = tf.image.combined_non_max_suppression(
boxes=mask_boxes.float(),
scores=iou_scores,
idxs=torch.zeros(mask_boxes.shape[0]),
iou_threshold=amg_crops_nms_thresh,
)
iou_scores = iou_scores[keep_by_nms]
rle_masks = [rle_masks[i] for i in keep_by_nms]
mask_boxes = mask_boxes[keep_by_nms]
masks = [_rle_to_mask(rle) for rle in rle_masks]
return masks, iou_scores, rle_masks, mask_boxes

View File

@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
class SamPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
values.
"""
def __init__(self, config, downsample_rate=None) -> None:
def __init__(self, config, downsample_rate=None):
super().__init__()
self.hidden_size = config.hidden_size
@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
A transformer block with four layers:
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
the embeddings of the mask inputs
multimask_output (bool):
Whether to return multiple masks or a single mask.
output_attentions (bool, **optional**):
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
"""
batch_size, num_channels, height, width = image_embeddings.shape
@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`torch.Tensor`, **optionnal**):
points (`torch.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`torch.Tensor`, **optionnal**):
boxes (`torch.Tensor`, *optional*):
boxes to embed
masks (`torch.Tensor`, **optionnal**):
masks (`torch.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
class SamVisionAttention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
input_size = (
(config.image_size // config.patch_size, config.image_size // config.patch_size)
@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
class SamModel(SamPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config) -> None:
def __init__(self, config):
super().__init__(config)
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
mask_decoder_attentions = None
vision_hidden_states = None
if pixel_values is not None:
@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
"The batch size of the image embeddings and the input points must be the same. ",
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
" if you want to pass multiple points for the same image, make sure that you passed ",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
)
sparse_embeddings, dense_embeddings = self.prompt_encoder(

File diff suppressed because it is too large Load Diff

View File

@@ -22,12 +22,15 @@ import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_torch_available
from ...utils import TensorType, is_tf_available, is_torch_available
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
class SamProcessor(ProcessorMixin):
r"""
@@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin):
# pop arguments that are not used in the foward but used nevertheless
original_sizes = encoding_image_processor["original_sizes"]
if isinstance(original_sizes, torch.Tensor):
if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
original_sizes = original_sizes.numpy()
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
@@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin):
input_boxes = torch.from_numpy(input_boxes)
# boxes batch size of 1 by default
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
elif return_tensors == "tf":
input_boxes = tf.convert_to_tensor(input_boxes)
# boxes batch size of 1 by default
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
encoding_image_processor.update({"input_boxes": input_boxes})
if input_points is not None:
if return_tensors == "pt":
input_points = torch.from_numpy(input_points)
# point batch size of 1 by default
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
elif return_tensors == "tf":
input_points = tf.convert_to_tensor(input_points)
# point batch size of 1 by default
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
encoding_image_processor.update({"input_points": input_points})
if input_labels is not None:
if return_tensors == "pt":
input_labels = torch.from_numpy(input_labels)
# point batch size of 1 by default
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
elif return_tensors == "tf":
input_labels = tf.convert_to_tensor(input_labels)
# point batch size of 1 by default
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
encoding_image_processor.update({"input_labels": input_labels})
return encoding_image_processor
@@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin):
it is converted to a `numpy.ndarray` and then to a `list`.
"""
if input_points is not None:
if isinstance(input_points, torch.Tensor):
if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
input_points = input_points.numpy().tolist()
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
@@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin):
input_points = None
if input_labels is not None:
if isinstance(input_labels, torch.Tensor):
if hasattr(input_labels, "numpy"):
input_labels = input_labels.numpy().tolist()
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
@@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin):
input_labels = None
if input_boxes is not None:
if isinstance(input_boxes, torch.Tensor):
if hasattr(input_boxes, "numpy"):
input_boxes = input_boxes.numpy().tolist()
if (

View File

@@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
# This is a very simplified functional layernorm, designed to duplicate
# the functionality of PyTorch nn.functional.layer_norm when this is needed to port
# models in Transformers.
if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
# Get mean and variance on the axis to be normalized
mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
if axis != -1:
# Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
# on every dimension except axis
shape = [1] * inputs.shape.rank
shape[axis] = shape_list(inputs)[axis]
weight = tf.reshape(weight, shape)
bias = tf.reshape(bias, shape)
# Compute layer normalization using the batch_normalization
# function.
outputs = tf.nn.batch_normalization(
inputs,
mean,
variance,
offset=bias,
scale=weight,
variance_epsilon=epsilon,
)
return outputs
def flatten(input, start_dim=0, end_dim=-1):
# Replicates the behavior of torch.flatten in TF
# If end_dim or start_dim is negative, count them from the end
if end_dim < 0:
end_dim += input.shape.rank
if start_dim < 0:
start_dim += input.shape.rank
if start_dim == end_dim:
return input
in_shape = tf.shape(input)
flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
return tf.reshape(input, out_shape)
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).

View File

@@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSamModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSamPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_hidden_states_output(self):
pass
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
@slow
def test_model_from_pretrained(self):
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
def test_inference_mask_generation_one_point_one_bb(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
@@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
)
def test_inference_mask_generation_batched_points_batched_images(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")
@@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze().cpu()
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
EXPECTED_SCORES = torch.tensor(
[
@@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase):
],
]
)
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
def test_inference_mask_generation_one_point_one_bb_zero(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge")

View File

@@ -0,0 +1,671 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow SAM model. """
import inspect
import unittest
import numpy as np
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
from transformers.testing_utils import require_tf, slow
from transformers.utils import is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import SamProcessor, TFSamModel
from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
class TFSamPromptEncoderTester:
def __init__(
self,
hidden_size=32,
input_image_size=24,
patch_size=2,
mask_input_channels=4,
num_point_embeddings=4,
hidden_act="gelu",
):
self.hidden_size = hidden_size
self.input_image_size = input_image_size
self.patch_size = patch_size
self.mask_input_channels = mask_input_channels
self.num_point_embeddings = num_point_embeddings
self.hidden_act = hidden_act
def get_config(self):
return SamPromptEncoderConfig(
image_size=self.input_image_size,
patch_size=self.patch_size,
mask_input_channels=self.mask_input_channels,
hidden_size=self.hidden_size,
num_point_embeddings=self.num_point_embeddings,
hidden_act=self.hidden_act,
)
def prepare_config_and_inputs(self):
dummy_points = floats_tensor([self.batch_size, 3, 2])
config = self.get_config()
return config, dummy_points
class TFSamMaskDecoderTester:
def __init__(
self,
hidden_size=32,
hidden_act="relu",
mlp_dim=64,
num_hidden_layers=2,
num_attention_heads=4,
attention_downsample_rate=2,
num_multimask_outputs=3,
iou_head_depth=3,
iou_head_hidden_dim=32,
layer_norm_eps=1e-6,
):
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_dim = mlp_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_downsample_rate = attention_downsample_rate
self.num_multimask_outputs = num_multimask_outputs
self.iou_head_depth = iou_head_depth
self.iou_head_hidden_dim = iou_head_hidden_dim
self.layer_norm_eps = layer_norm_eps
def get_config(self):
return SamMaskDecoderConfig(
hidden_size=self.hidden_size,
hidden_act=self.hidden_act,
mlp_dim=self.mlp_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
attention_downsample_rate=self.attention_downsample_rate,
num_multimask_outputs=self.num_multimask_outputs,
iou_head_depth=self.iou_head_depth,
iou_head_hidden_dim=self.iou_head_hidden_dim,
layer_norm_eps=self.layer_norm_eps,
)
def prepare_config_and_inputs(self):
config = self.get_config()
dummy_inputs = {
"image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
}
return config, dummy_inputs
class TFSamModelTester:
def __init__(
self,
parent,
hidden_size=36,
intermediate_size=72,
projection_dim=62,
output_channels=32,
num_hidden_layers=2,
num_attention_heads=4,
num_channels=3,
image_size=24,
patch_size=2,
hidden_act="gelu",
layer_norm_eps=1e-06,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
qkv_bias=True,
mlp_ratio=4.0,
use_abs_pos=True,
use_rel_pos=True,
rel_pos_zero_init=False,
window_size=14,
global_attn_indexes=[2, 5, 8, 11],
num_pos_feats=16,
mlp_dim=None,
batch_size=2,
):
self.parent = parent
self.image_size = image_size
self.patch_size = patch_size
self.output_channels = output_channels
self.num_channels = num_channels
self.hidden_size = hidden_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.mlp_ratio = mlp_ratio
self.use_abs_pos = use_abs_pos
self.use_rel_pos = use_rel_pos
self.rel_pos_zero_init = rel_pos_zero_init
self.window_size = window_size
self.global_attn_indexes = global_attn_indexes
self.num_pos_feats = num_pos_feats
self.mlp_dim = mlp_dim
self.batch_size = batch_size
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.prompt_encoder_tester = TFSamPromptEncoderTester()
self.mask_decoder_tester = TFSamMaskDecoderTester()
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
vision_config = SamVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
projection_dim=self.projection_dim,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
initializer_factor=self.initializer_factor,
output_channels=self.output_channels,
qkv_bias=self.qkv_bias,
mlp_ratio=self.mlp_ratio,
use_abs_pos=self.use_abs_pos,
use_rel_pos=self.use_rel_pos,
rel_pos_zero_init=self.rel_pos_zero_init,
window_size=self.window_size,
global_attn_indexes=self.global_attn_indexes,
num_pos_feats=self.num_pos_feats,
mlp_dim=self.mlp_dim,
)
prompt_encoder_config = self.prompt_encoder_tester.get_config()
mask_decoder_config = self.mask_decoder_tester.get_config()
return SamConfig(
vision_config=vision_config,
prompt_encoder_config=prompt_encoder_config,
mask_decoder_config=mask_decoder_config,
)
def create_and_check_model(self, config, pixel_values):
model = TFSamModel(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
def create_and_check_get_image_features(self, config, pixel_values):
model = TFSamModel(config=config)
result = model.get_image_embeddings(pixel_values)
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
def create_and_check_get_image_hidden_states(self, config, pixel_values):
model = TFSamModel(config=config)
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=True,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
result = model.vision_encoder(
pixel_values,
output_hidden_states=True,
return_dict=False,
)
# after computing the convolutional features
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFSamModel,) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFSamModel, "mask-generation": TFSamModel} if is_tf_available() else {}
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
return True
def setUp(self):
self.model_tester = TFSamModelTester(self)
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
self.prompt_encoder_config_tester = ConfigTester(
self,
config_class=SamPromptEncoderConfig,
has_text_modality=False,
num_attention_heads=12,
num_hidden_layers=2,
)
self.mask_decoder_config_tester = ConfigTester(
self, config_class=SamMaskDecoderConfig, has_text_modality=False
)
def test_config(self):
self.vision_config_tester.run_common_tests()
self.prompt_encoder_config_tester.run_common_tests()
self.mask_decoder_config_tester.run_common_tests()
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_get_image_features(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
def test_image_hidden_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
expected_vision_attention_shape = (
self.model_tester.batch_size * self.model_tester.num_attention_heads,
196,
196,
)
expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
vision_attentions = outputs.vision_attentions
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
mask_decoder_attentions = outputs.mask_decoder_attentions
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
vision_attentions = outputs.vision_attentions
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
mask_decoder_attentions = outputs.mask_decoder_attentions
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
self.assertListEqual(
list(vision_attentions[0].shape[-4:]),
list(expected_vision_attention_shape),
)
self.assertListEqual(
list(mask_decoder_attentions[0].shape[-4:]),
list(expected_mask_decoder_attention_shape),
)
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
def test_hidden_states_output(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFSamModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
super().check_pt_tf_outputs(
tf_outputs=tf_outputs,
pt_outputs=pt_outputs,
model_class=model_class,
tol=tol,
name=name,
attributes=attributes,
)
def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
def prepare_dog_img():
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return raw_image
@slow
class SamModelIntegrationTest(unittest.TestCase):
def test_inference_mask_generation_no_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
inputs = processor(images=raw_image, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4))
self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2))
def test_inference_mask_generation_one_point_one_bb(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_boxes = [[[650, 900, 1000, 1250]]]
input_points = [[[820, 1080]]]
inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4))
self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2))
def test_inference_mask_generation_batched_points_batched_images(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [
[[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
[[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
]
inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
masks = outputs.pred_masks[0, 0, 0, 0, :3]
EXPECTED_SCORES = np.array(
[
[
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
],
[
[0.8405, 0.6292, 0.3840],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
[0.9673, 0.9441, 0.9084],
],
]
)
EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406])
self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3))
self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2))
def test_inference_mask_generation_one_point_one_bb_zero(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_boxes = [[[620, 900, 1000, 1255]]]
input_points = [[[820, 1080]]]
labels = [[0]]
inputs = processor(
images=raw_image,
input_boxes=input_boxes,
input_points=input_points,
input_labels=labels,
return_tensors="tf",
)
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4))
def test_inference_mask_generation_one_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [[[400, 650]]]
input_labels = [[1]]
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4))
# With no label
input_points = [[[400, 650]]]
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4))
def test_inference_mask_generation_two_points(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [[[400, 650], [800, 650]]]
input_labels = [[1, 1]]
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
# no labels
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
def test_inference_mask_generation_two_points_batched(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_points = [[[400, 650], [800, 650]], [[400, 650]]]
input_labels = [[1, 1], [1]]
inputs = processor(
images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="tf"
)
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4))
self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4))
def test_inference_mask_generation_one_box(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
input_boxes = [[[75, 275, 1725, 850]]]
inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf")
outputs = model(**inputs)
scores = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4))
def test_inference_mask_generation_batched_image_one_point(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
raw_dog_image = prepare_dog_img()
input_points = [[[820, 1080]], [[220, 470]]]
inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores_batched = tf.squeeze(outputs.iou_scores)
input_points = [[[220, 470]]]
inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
scores_single = tf.squeeze(outputs.iou_scores)
self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4))
def test_inference_mask_generation_two_points_point_batch(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
# fmt: off
input_points = tf.convert_to_tensor([[[400, 650]], [[220, 470]]])
# fmt: on
input_points = tf.expand_dims(input_points, 0)
inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)
iou_scores = outputs.iou_scores
self.assertTrue(iou_scores.shape == (1, 2, 3))
self.assertTrue(
np.allclose(
iou_scores.numpy(),
np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]),
atol=1e-4,
rtol=1e-4,
)
)
def test_inference_mask_generation_three_boxes_point_batch(self):
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
raw_image = prepare_image()
# fmt: off
input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]])
EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]])
# fmt: on
input_boxes = tf.expand_dims(input_boxes, 0)
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf")
outputs = model(**inputs)
iou_scores = outputs.iou_scores
self.assertTrue(iou_scores.shape == (1, 3, 3))
self.assertTrue(np.allclose(iou_scores.numpy(), EXPECTED_IOU, atol=1e-4, rtol=1e-4))

View File

@@ -17,8 +17,14 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torchvision, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
is_pt_tf_cross_test,
require_tf,
require_torch,
require_torchvision,
require_vision,
)
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
if is_vision_available():
@@ -29,6 +35,9 @@ if is_vision_available():
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
@require_vision
@require_torchvision
@@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase):
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(ValueError):
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
@require_vision
@require_tf
class TFSamProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, SamImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
input_feat_extract = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")
input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor
input_feat_extract.pop("reshaped_input_sizes") # pop reshaped_input_sizes as it is popped in the processor
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
@require_tf
def test_post_process_masks(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = [tf.ones((1, 3, 5, 5))]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf")
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
masks = processor.post_process_masks(
dummy_masks,
tf.convert_to_tensor(original_sizes),
tf.convert_to_tensor(reshaped_input_size),
return_tensors="tf",
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
# should also work with np
dummy_masks = [np.ones((1, 3, 5, 5))]
masks = processor.post_process_masks(
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
dummy_masks = [[1, 0], [0, 1]]
with self.assertRaises(tf.errors.InvalidArgumentError):
masks = processor.post_process_masks(
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)
@require_vision
@require_torchvision
class SamProcessorEquivalenceTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
processor = SamProcessor(image_processor)
processor.save_pretrained(self.tmpdirname)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
@is_pt_tf_cross_test
def test_post_process_masks_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
pt_dummy_masks = [torch.tensor(dummy_masks)]
original_sizes = [[1764, 2646]]
reshaped_input_size = [[683, 1024]]
tf_masks = processor.post_process_masks(
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
)
pt_masks = processor.post_process_masks(
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
)
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
@is_pt_tf_cross_test
def test_image_processor_equivalence(self):
image_processor = self.get_image_processor()
processor = SamProcessor(image_processor=image_processor)
image_input = self.prepare_image_inputs()
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))