diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 244b94cebb..497803322f 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow. | RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | RWKV | ❌ | ❌ | ✅ | ❌ | ❌ | -| SAM | ❌ | ❌ | ✅ | ❌ | ❌ | +| SAM | ❌ | ❌ | ✅ | ✅ | ❌ | | SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ | | SEW | ❌ | ❌ | ✅ | ❌ | ❌ | | SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/sam.mdx b/docs/source/en/model_doc/sam.mdx index 969b7e2b22..ac8aa4a573 100644 --- a/docs/source/en/model_doc/sam.mdx +++ b/docs/source/en/model_doc/sam.mdx @@ -99,3 +99,9 @@ Resources: [[autodoc]] SamModel - forward + + +## TFSamModel + +[[autodoc]] TFSamModel + - call \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c079981490..10bf93633a 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3406,6 +3406,13 @@ else: "TFRoFormerPreTrainedModel", ] ) + _import_structure["models.sam"].extend( + [ + "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSamModel", + "TFSamPreTrainedModel", + ] + ) _import_structure["models.segformer"].extend( [ "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -6657,6 +6664,11 @@ if TYPE_CHECKING: TFRoFormerModel, TFRoFormerPreTrainedModel, ) + from .models.sam import ( + TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, + TFSamModel, + TFSamPreTrainedModel, + ) from .models.segformer import ( TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, TFSegformerDecodeHead, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 5d637271f9..bfc29f2dd3 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ("roberta", "TFRobertaModel"), ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), ("roformer", "TFRoFormerModel"), + ("sam", "TFSamModel"), ("segformer", "TFSegformerModel"), ("speech_to_text", "TFSpeech2TextModel"), ("swin", "TFSwinModel"), @@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( ("mobilebert", "TFMobileBertForNextSentencePrediction"), ] ) +TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( + [ + ("sam", "TFSamModel"), + ] +) TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) @@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES ) +TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES +) + + +class TFAutoModelForMaskGeneration(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING + class TFAutoModel(_BaseAutoModelClass): _model_mapping = TF_MODEL_MAPPING diff --git a/src/transformers/models/sam/__init__.py b/src/transformers/models/sam/__init__.py index d9bbf5f5ea..e8006e89e0 100644 --- a/src/transformers/models/sam/__init__.py +++ b/src/transformers/models/sam/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tf_available, + is_torch_available, + is_vision_available, +) _import_structure = { @@ -39,6 +45,17 @@ else: "SamModel", "SamPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_sam"] = [ + "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSamModel", + "TFSamPreTrainedModel", + ] try: if not is_vision_available(): raise OptionalDependencyNotAvailable() @@ -66,6 +83,14 @@ if TYPE_CHECKING: else: from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel + try: if not is_vision_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 5385296104..64f3bae222 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -34,7 +34,14 @@ from ...image_utils import ( to_numpy_array, valid_images, ) -from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends +from ...utils import ( + TensorType, + is_tf_available, + is_torch_available, + is_torchvision_available, + logging, + requires_backends, +) if is_torch_available(): @@ -44,6 +51,12 @@ if is_torch_available(): if is_torchvision_available(): from torchvision.ops.boxes import batched_nms +if is_tf_available(): + import tensorflow as tf + from tensorflow.experimental import numpy as tnp + + from ...tf_utils import flatten, shape_list + logger = logging.get_logger(__name__) @@ -372,6 +385,61 @@ class SamImageProcessor(BaseImageProcessor): return encoded_outputs def post_process_masks( + self, + masks, + original_sizes, + reshaped_input_sizes, + mask_threshold=0.0, + binarize=True, + pad_size=None, + return_tensors="pt", + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + return_tensors (`str`, *optional*, defaults to `"pt"`): + If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors. + Returns: + (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where + (height, width) is given by original_size. + """ + if return_tensors == "pt": + return self._post_process_masks_pt( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + elif return_tensors == "tf": + return self._post_process_masks_tf( + masks=masks, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + mask_threshold=mask_threshold, + binarize=binarize, + pad_size=pad_size, + ) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'") + + def _post_process_masks_pt( self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None ): """ @@ -418,21 +486,70 @@ class SamImageProcessor(BaseImageProcessor): return output_masks - def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + def _post_process_masks_tf( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`tf.Tensor`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`tf.Tensor`): + The original size of the images before resizing for input to the model, in (height, width) format. + reshaped_input_sizes (`tf.Tensor`): + The size of the image input to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is + given by original_size. + """ + requires_backends(self, ["tf"]) + pad_size = self.pad_size if pad_size is None else pad_size + target_image_size = (pad_size["height"], pad_size["width"]) + + output_masks = [] + for i, original_size in enumerate(original_sizes): + # tf.image expects NHWC, we transpose the NCHW inputs for it + mask = tf.transpose(masks[i], perm=[0, 2, 3, 1]) + interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear") + interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :] + interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear") + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + # And then we transpose them back at the end + output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2])) + + return output_masks + + def post_process_for_mask_generation( + self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt" + ): """ Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. Args: - all_masks (`List[torch.Tensor]`): + all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`): List of all predicted segmentation masks - all_scores (`List[torch.Tensor]`): + all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`): List of all predicted iou scores - all_boxes (`List[torch.Tensor]`): + all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`): List of all bounding boxes of the predicted masks crops_nms_thresh (`float`): Threshold for NMS (Non Maximum Suppression) algorithm. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. """ - return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) + if return_tensors == "pt": + return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh) + elif return_tensors == "tf": + return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh) def generate_crop_boxes( self, @@ -443,6 +560,7 @@ class SamImageProcessor(BaseImageProcessor): points_per_crop: Optional[int] = 32, crop_n_points_downscale_factor: Optional[List[int]] = 1, device: Optional["torch.device"] = None, + return_tensors: str = "pt", ): """ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. @@ -464,10 +582,35 @@ class SamImageProcessor(BaseImageProcessor): The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. device (`torch.device`, *optional*, defaults to None): Device to use for the computation. If None, cpu will be used. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. """ - return _generate_crop_boxes( - image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, ) + if return_tensors == "pt": + if device is None: + device = torch.device("cpu") + crop_boxes = torch.tensor(crop_boxes, device=device) + points_per_crop = torch.tensor(points_per_crop, device=device) + # cropped_images stays as np + input_labels = torch.tensor(input_labels, device=device) + + elif return_tensors == "tf": + if device is not None: + raise ValueError("device is not a supported argument when return_tensors is tf!") + crop_boxes = tf.convert_to_tensor(crop_boxes) + points_per_crop = tf.convert_to_tensor(points_per_crop) + # cropped_images stays as np + input_labels = tf.convert_to_tensor(input_labels) + else: + raise ValueError("return_tensors must be either 'pt' or 'tf'.") + return crop_boxes, points_per_crop, cropped_images, input_labels def filter_masks( self, @@ -479,6 +622,67 @@ class SamImageProcessor(BaseImageProcessor): stability_score_thresh=0.95, mask_threshold=0, stability_score_offset=1, + return_tensors="pt", + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`Union[torch.Tensor, tf.Tensor]`): + Input masks. + iou_scores (`Union[torch.Tensor, tf.Tensor]`): + List of IoU scores. + original_size (`Tuple[int,int]`): + Size of the orginal image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + return_tensors (`str`, *optional*, defaults to `pt`): + If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`. + """ + if return_tensors == "pt": + return self._filter_masks_pt( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + elif return_tensors == "tf": + return self._filter_masks_tf( + masks=masks, + iou_scores=iou_scores, + original_size=original_size, + cropped_box_image=cropped_box_image, + pred_iou_thresh=pred_iou_thresh, + stability_score_thresh=stability_score_thresh, + mask_threshold=mask_threshold, + stability_score_offset=stability_score_offset, + ) + + def _filter_masks_pt( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, ): """ Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being @@ -525,7 +729,7 @@ class SamImageProcessor(BaseImageProcessor): # compute stability score if stability_score_thresh > 0.0: - stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset) keep_mask = keep_mask & (stability_scores > stability_score_thresh) scores = iou_scores[keep_mask] @@ -549,8 +753,85 @@ class SamImageProcessor(BaseImageProcessor): return masks, scores, converted_boxes + def _filter_masks_tf( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. -def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): + Args: + masks (`tf.Tensor`): + Input masks. + iou_scores (`tf.Tensor`): + List of IoU scores. + original_size (`Tuple[int,int]`): + Size of the orginal image. + cropped_box_image (`np.array`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + requires_backends(self, ["tf"]) + original_height, original_width = original_size + iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]]) + masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]]) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the same batch size.") + + batch_size = masks.shape[0] + + keep_mask = tf.ones(batch_size, dtype=tf.bool) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box_tf(masks) + + keep_mask = ~_is_box_near_crop_edge_tf( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppresion + masks = _mask_to_rle_tf(masks) + + return masks, scores, converted_boxes + + +def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int): # One mask is always contained inside the other. # Save memory by preventing unnecesary cast to torch.int64 intersections = ( @@ -561,6 +842,17 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi return stability_scores +def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int): + # Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure + # we get the right division results. + intersections = tf.count_nonzero( + masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32 + ) + unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32) + stability_scores = intersections / unions + return stability_scores + + def _build_point_grid(n_per_side: int) -> np.ndarray: """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" offset = 1 / (2 * n_per_side) @@ -606,7 +898,6 @@ def _generate_crop_boxes( overlap_ratio: float = 512 / 1500, points_per_crop: Optional[int] = 32, crop_n_points_downscale_factor: Optional[List[int]] = 1, - device: Optional["torch.device"] = None, ) -> Tuple[List[List[int]], List[int]]: """ Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. @@ -626,11 +917,7 @@ def _generate_crop_boxes( Number of points to sample per crop. crop_n_points_downscale_factor (`int`, *optional*): The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n. - device (`torch.device`, *optional*): - Device to run the crop generation on. Defaults to CPU. """ - if device is None: - device = torch.device("cpu") if isinstance(image, list): raise ValueError("Only one image is allowed for crop generation.") @@ -648,12 +935,11 @@ def _generate_crop_boxes( crop_boxes, image, points_grid, layer_idxs, target_size, original_size ) - crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device) - point_grid_per_crop = np.array([point_grid_per_crop]) - points_per_crop = torch.tensor(point_grid_per_crop, device=device) - points_per_crop = points_per_crop.permute(0, 2, 1, 3) + crop_boxes = crop_boxes.astype(np.float32) + points_per_crop = np.array([point_grid_per_crop]) + points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3)) - input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device) + input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64) return crop_boxes, points_per_crop, cropped_images, input_labels @@ -730,6 +1016,16 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int): return torch.nn.functional.pad(masks, pad, value=0) +def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return tf.pad(masks, pad, constant_values=0) + + def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): """Filter masks at the edge of a crop, but not at the edge of the original image.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) @@ -748,6 +1044,24 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): return torch.any(near_crop_edge, dim=1) +def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32) + orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32) + + left, top, _, _ = crop_box + offset = tf.convert_to_tensor([[left, top, left, top]]) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = tf.expand_dims(offset, 1) + boxes = tf.cast(boxes + offset, tf.float32) + + near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0) + near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0) + near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge) + return tf.reduce_any(near_crop_edge, axis=1) + + def _batched_mask_to_box(masks: "torch.Tensor"): """ Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which @@ -797,6 +1111,54 @@ def _batched_mask_to_box(masks: "torch.Tensor"): return out +def _batched_mask_to_box_tf(masks: "tf.Tensor"): + """ + Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which + corresponds the following required indices: + - LEFT: left hand side of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + + Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape + is channel_1 x channel_2 x ... x 4. + + Args: + - masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`) + """ + + if tf.size(masks) == 0: + return tf.zeros([*masks.shape[:-2], 4]) + + # Normalize shape to Cxheightxwidth + shape = shape_list(masks) + height, width = shape[-2:] + + # Get top and bottom edges + in_height = tf.reduce_max(masks, axis=-1) + in_height_coords = in_height * tf.range(height)[None, :] + bottom_edges = tf.reduce_max(in_height_coords, axis=-1) + in_height_coords = in_height_coords + height * (~in_height) + top_edges = tf.reduce_min(in_height_coords, axis=-1) + + # Get left and right edges + in_width, _ = tf.reduce_max(masks, axis=-2) + in_width_coords = in_width * tf.range(width)[None, :] + right_edges, _ = tf.reduce_max(in_width_coords, axis=-1) + in_width_coords = in_width_coords + width * (~in_width) + left_edges, _ = tf.reduce_min(in_width_coords, axis=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1) + out = out * tf.expand_dims(~empty_filter, -1) + + # Return to original shape + out = tf.reshape(out, *shape[:-2], 4) + return out + + def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): """ Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. @@ -820,6 +1182,29 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): return out +def _mask_to_rle_tf(input_mask: "tf.Tensor"): + """ + Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. + """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = tf.where(diff) + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]] + out.append({"size": [height, width], "counts": counts}) + return out + + def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: """Compute a binary mask from an uncompressed RLE.""" height, width = rle["size"] @@ -836,7 +1221,7 @@ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): """ - Perform NMS (Non Maxium Suppression) on the outputs. + Perform NMS (Non Maximum Suppression) on the outputs. Args: rle_masks (`torch.Tensor`): @@ -861,3 +1246,32 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh= masks = [_rle_to_mask(rle) for rle in rle_masks] return masks, iou_scores, rle_masks, mask_boxes + + +def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maximum Suppression) on the outputs. + + Args: + rle_masks (`tf.Tensor`): + binary masks in the RLE format + iou_scores (`tf.Tensor` of shape (nb_masks, 1)): + iou_scores predicted by the model + mask_boxes (`tf.Tensor`): + The bounding boxes corresponding to segmentation masks + amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): + NMS threshold. + """ + keep_by_nms = tf.image.combined_non_max_suppression( + boxes=mask_boxes.float(), + scores=iou_scores, + idxs=torch.zeros(mask_boxes.shape[0]), + iou_threshold=amg_crops_nms_thresh, + ) + + iou_scores = iou_scores[keep_by_nms] + rle_masks = [rle_masks[i] for i in keep_by_nms] + mask_boxes = mask_boxes[keep_by_nms] + masks = [_rle_to_mask(rle) for rle in rle_masks] + + return masks, iou_scores, rle_masks, mask_boxes diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index bf14a4b241..7df4611750 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput): mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings class SamPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial @@ -198,7 +197,7 @@ class SamAttention(nn.Module): values. """ - def __init__(self, config, downsample_rate=None) -> None: + def __init__(self, config, downsample_rate=None): super().__init__() self.hidden_size = config.hidden_size @@ -252,7 +251,7 @@ class SamAttention(nn.Module): class SamTwoWayAttentionBlock(nn.Module): - def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None: + def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False): """ A transformer block with four layers: (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on @@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module): the embeddings of the mask inputs multimask_output (bool): Whether to return multiple masks or a single mask. - output_attentions (bool, **optional**): + output_attentions (bool, *optional*): Whether or not to return the attentions tensors of all attention layers. """ batch_size, num_channels, height, width = image_embeddings.shape @@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module): Embeds different types of prompts, returning both sparse and dense embeddings. Args: - points (`torch.Tensor`, **optionnal**): + points (`torch.Tensor`, *optional*): point coordinates and labels to embed. - boxes (`torch.Tensor`, **optionnal**): + boxes (`torch.Tensor`, *optional*): boxes to embed - masks (`torch.Tensor`, **optionnal**): + masks (`torch.Tensor`, *optional*): masks to embed """ sparse_embeddings = None @@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module): class SamVisionAttention(nn.Module): """Multi-head Attention block with relative position embeddings.""" - def __init__(self, config, window_size) -> None: + def __init__(self, config, window_size): super().__init__() input_size = ( (config.image_size // config.patch_size, config.image_size // config.patch_size) @@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module): class SamVisionLayer(nn.Module): - def __init__(self, config, window_size) -> None: + def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.attn = SamVisionAttention(config, window_size) @@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r""" class SamModel(SamPreTrainedModel): _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] - def __init__(self, config) -> None: + def __init__(self, config): super().__init__(config) self.shared_image_embedding = SamPositionalEmbedding(config.vision_config) @@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel): image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) vision_attentions = None - mask_decoder_attentions = None vision_hidden_states = None if pixel_values is not None: @@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel): "The batch size of the image embeddings and the input points must be the same. ", "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), " if you want to pass multiple points for the same image, make sure that you passed ", - " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", ) sparse_embeddings, dense_embeddings = self.prompt_encoder( diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py new file mode 100644 index 0000000000..ddd8e526a7 --- /dev/null +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -0,0 +1,1497 @@ +# coding=utf-8 +# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a +discrepancy, the original file should be regarded as the 'reference' version. +""" + +import collections +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import TFBaseModelOutput +from ...modeling_tf_utils import TFPreTrainedModel, shape_list, unpack_inputs +from ...tf_utils import flatten, functional_layernorm +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SamConfig" +_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge" + +TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/sam-vit-huge", + "facebook/sam-vit-large", + "facebook/sam-vit-base", + # See all SAM models at https://huggingface.co/models?filter=sam +] + + +@dataclass +class TFSamVisionEncoderOutput(ModelOutput): + """ + Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[tf.Tensor] = None + last_hidden_state: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + + +@dataclass +class TFSamImageSegmentationOutput(ModelOutput): + """ + Base class for Segment-Anything model's output + + Args: + iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. Needs to be post-processed by the processor + vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: tf.Tensor = None + pred_masks: tf.Tensor = None + vision_hidden_states: Optional[Tuple[tf.Tensor]] = None + vision_attentions: Optional[Tuple[tf.Tensor]] = None + mask_decoder_attentions: Optional[Tuple[tf.Tensor]] = None + + +class TFSamPatchEmbeddings(tf.keras.layers.Layer): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = tf.keras.layers.Conv2D( + hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" + ) + + def call(self, pixel_values): + batch_size, num_channels, height, width = shape_list(pixel_values) + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1])) + return embeddings + + +class TFSamMLPBlock(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.lin1 = tf.keras.layers.Dense(config.mlp_dim, name="lin1") + self.lin2 = tf.keras.layers.Dense(config.hidden_size, name="lin2") + self.act = ACT2FN[config.hidden_act] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class TFSamLayerNorm(tf.keras.layers.Layer): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.data_format = data_format + self.normalized_shape = normalized_shape + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + + def build(self, input_shape): + self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight") + self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias") + super().build(input_shape) + + def call(self, x: tf.Tensor) -> tf.Tensor: + if self.data_format == "channels_last": + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) + elif self.data_format == "channels_first": + x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) + return x + + +class TFSamAttention(tf.keras.layers.Layer): + """ + SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None, **kwargs): + super().__init__(**kwargs) + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = tf.keras.layers.Dense(self.internal_dim, name="q_proj") + self.k_proj = tf.keras.layers.Dense(self.internal_dim, name="k_proj") + self.v_proj = tf.keras.layers.Dense(self.internal_dim, name="v_proj") + self.out_proj = tf.keras.layers.Dense(self.hidden_size, name="out_proj") + + def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: + batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) + c_per_head = channel // num_attention_heads + hidden_states = tf.reshape( + hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + ) + return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + + def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: + batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) + return tf.reshape( + hidden_states, (batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head) + ) + + def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = shape_list(query)[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamAttention + _, _, _, c_per_head = shape_list(query) + attn = tf.matmul( + query, tf.transpose(key, perm=[0, 1, 3, 2]) + ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / tf.math.sqrt(float(c_per_head)) + attn = tf.nn.softmax(attn, axis=-1) + + # Get output + out = tf.matmul(attn, value) + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer): + def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`SamMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. + """ + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn") + self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1") + + self.cross_attn_token_to_image = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image" + ) + self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2") + + self.mlp = TFSamMLPBlock(config, name="mlp") + self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3") + + self.layer_norm4 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4") + self.cross_attn_image_to_token = TFSamAttention( + config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token" + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def call( + self, + queries: tf.Tensor, + keys: tf.Tensor, + query_point_embedding: tf.Tensor, + key_point_embedding: tf.Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs + + +class TFSamTwoWayTransformer(tf.keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = [] + + for i in range(self.num_hidden_layers): + self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}")) + + self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") + self.layer_norm_final_attn = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layer_norm_final_attn" + ) + + def call( + self, + point_embeddings: tf.Tensor, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TFBaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None] + image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None] + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attenion layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + + +class TFSamFeedForward(tf.keras.layers.Layer): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs + ): + super().__init__(**kwargs) + self.num_layers = num_layers + self.activation = tf.keras.layers.ReLU() + self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") + self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") + self.layers = [ + tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}") + for i in range(num_layers - 2) + ] + self.sigmoid_output = sigmoid_output + + def call(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = tf.sigmoid(hidden_states) + return hidden_states + + +class TFSamMaskDecoder(tf.keras.layers.Layer): + def __init__(self, config: SamMaskDecoderConfig, **kwargs): + super().__init__(**kwargs) + + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.transformer = TFSamTwoWayTransformer(config, name="transformer") + + self.upscale_conv1 = tf.keras.layers.Conv2DTranspose( + self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first" + ) + self.upscale_conv2 = tf.keras.layers.Conv2DTranspose( + self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first" + ) + self.upscale_layer_norm = TFSamLayerNorm( + self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" + ) + self.activation = tf.nn.gelu + + mlps_list = [] + for i in range(self.num_mask_tokens): + mlps_list += [ + TFSamFeedForward( + self.hidden_size, + self.hidden_size, + self.hidden_size // 8, + 3, + name=f"output_hypernetworks_mlps_._{i}", + ) + ] + self.output_hypernetworks_mlps = mlps_list + + self.iou_prediction_head = TFSamFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + name="iou_prediction_head", + ) + + def build(self, input_shape): + self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) + self.mask_tokens = self.add_weight( + shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True + ) + super().build(input_shape) + + def call( + self, + image_embeddings: tf.Tensor, + image_positional_embeddings: tf.Tensor, + sparse_prompt_embeddings: tf.Tensor, + dense_prompt_embeddings: tf.Tensor, + multimask_output: bool, + output_attentions: Optional[bool] = None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + batch_size, num_channels, height, width = shape_list(image_embeddings) + point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) + + output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) # Should be (1, 32) + (4, 32) = (5, 32) + output_tokens = tf.tile( + output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] + ) # Should be (batch_size, point_size, 5, 32) + + # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only + # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced + # it with an explicit shape check to avoid data-dependent control flow which breaks XLA. + if sparse_prompt_embeddings.shape[1] != 0: + tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) + else: + tokens = output_tokens + point_embeddings = tf.cast(tokens, self.iou_token.dtype) + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) + image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1]) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) + image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + hyper_in_list = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = tf.stack(hyper_in_list, axis=2) + + _, num_channels, height, width = shape_list(upscaled_embedding) + upscaled_embedding = tf.reshape( + upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width] + ) + masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) + + iou_pred = self.iou_prediction_head(iou_token_out) + + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + outputs = (masks, iou_pred) + + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class TFSamPositionalEmbedding(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.scale = config.hidden_size // 2 + self.config = config + + def build(self, input_shape): + # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized? + self.positional_embedding = self.add_weight( + name="positional_embedding", + shape=(2, self.config.num_pos_feats), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), + trainable=False, + ) + + def call(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = tf.identity(input_coords) + + if input_shape is not None: + coordinates = tf.stack( + [ + tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], + tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0], + ], + axis=-1, + ) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = tf.cast(coordinates, self.positional_embedding.dtype) + coordinates = tf.matmul(coordinates, self.positional_embedding) + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) + + +class TFSamMaskEmbedding(tf.keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, **kwargs): + super().__init__(**kwargs) + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") + self.conv2 = tf.keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2") + self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") + self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") + self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") + + def call(self, masks): + masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first + return dense_embeddings + + def build(self, input_shape): + # This class needs an explicit build method because it isn't called with the standard dummy inputs + conv1_shape = [None, None, None, 1] + conv2_shape = [None, None, None, self.mask_input_channels] + conv3_shape = [None, None, None, self.mask_input_channels * 4] + layer_norm1_shape = [None, None, None, self.mask_input_channels] + layer_norm2_shape = [None, None, None, self.mask_input_channels * 4] + with tf.name_scope("conv1"): + self.conv1.build(conv1_shape) + with tf.name_scope("conv2"): + self.conv2.build(conv2_shape) + with tf.name_scope("conv3"): + self.conv3.build(conv3_shape) + with tf.name_scope("layer_norm1"): + self.layer_norm1.build(layer_norm1_shape) + with tf.name_scope("layer_norm2"): + self.layer_norm2.build(layer_norm2_shape) + super().build(input_shape) + + +class TFSamPromptEncoder(tf.keras.layers.Layer): + def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): + super().__init__(**kwargs) + self.shared_embedding = shared_patch_embedding + self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed") + self.no_mask_embed = None + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = [] + self.hidden_size = config.hidden_size + self.not_a_point_embed = None + self.config = config + + def build(self, input_shape): + self.no_mask_embed = self.add_weight( + name="no_mask_embed.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + self.point_embed = [ + self.add_weight( + name=f"point_embed_._{i}.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + for i in range(self.config.num_point_embeddings) + ] + self.not_a_point_embed = self.add_weight( + name="not_a_point_embed.weight", + shape=(1, self.hidden_size), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), + trainable=True, + ) + with tf.name_scope("mask_embed"): + # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs + self.mask_embed.build( + (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size) + ) + super().build(input_shape) + + def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = tf.zeros(target_point_shape, dtype=points.dtype) + padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) + points = tf.concat([points, padding_point], axis=2) + labels = tf.concat([labels, padding_label], axis=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding) + + point_embedding = tf.where( + labels[..., None] != -10, + point_embedding, + tf.zeros_like(point_embedding), + ) + point_embedding = tf.where( + (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding + ) + point_embedding = tf.where( + (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding + ) + return point_embedding + + def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding += tf.where( + tf.range(corner_embedding.shape[2])[None, None, :, None] == 0, + self.point_embed[2][0], + self.point_embed[3][0], + ) + return corner_embedding + + def call( + self, + batch_size: Optional[int], + input_points: Optional[Tuple[tf.Tensor, tf.Tensor]], + input_labels: Optional[tf.Tensor], + input_boxes: Optional[tf.Tensor], + input_masks: Optional[tf.Tensor], + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`tf.Tensor`, *optional*): + point coordinates and labels to embed. + boxes (`tf.Tensor`, *optional*): + boxes to embed + masks (`tf.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + if input_points is not None: + batch_size, point_batch_size = input_points.shape[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = tf.zeros( + (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype + ) + sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed[0] + dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1)) + dense_embeddings = tf.tile( + dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) + ) + if sparse_embeddings is None: + sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) + + return sparse_embeddings, dense_embeddings + + +class TFSamVisionAttention(tf.keras.layers.Layer): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size, **kwargs): + super().__init__(**kwargs) + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + self.input_size = input_size + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = head_dim + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = tf.keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv") + self.proj = tf.keras.layers.Dense(config.hidden_size, name="proj") + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + def build(self, input_shape): + if self.input_size is not None: + # initialize relative positional embeddings + self.rel_pos_h = self.add_weight( + shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" + ) + self.rel_pos_w = self.add_weight( + shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w" + ) + super().build(input_shape) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`tf.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = tf.image.resize( + tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)), + size=(max_rel_dist, rel_pos.shape[1]), + method="bilinear", + ) + rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist)) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0) + k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32)) + + def add_decomposed_rel_pos( + self, + attn: tf.Tensor, + query: tf.Tensor, + rel_pos_h: tf.Tensor, + rel_pos_w: tf.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> tf.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`tf.Tensor`): + attention map. + query (`tf.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`tf.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`tf.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`tf.Tensor`): + attention map with added relative positional embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = shape_list(query) + reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) + rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width)) + attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2) + attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width)) + return attn + + def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: + batch_size, height, width, _ = shape_list(hidden_states) + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) + qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = tf.unstack( + tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 + ) + attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + if training: + attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) + else: + attn_probs = attn_weights + + attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) + attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) + attn_output = tf.reshape(attn_output, (batch_size, height, width, -1)) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class TFSamVisionLayer(tf.keras.layers.Layer): + def __init__(self, config, window_size, **kwargs): + super().__init__(**kwargs) + self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") + self.attn = TFSamVisionAttention(config, window_size, name="attn") + self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") + self.mlp = TFSamMLPBlock(config, name="mlp") + self.window_size = window_size + + def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]: + batch_size, height, width, channel = shape_list(hidden_states) + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + if pad_h > 0 or pad_w > 0: + hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = tf.reshape( + hidden_states, + [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel], + ) + windows = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel] + ) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> tf.Tensor: + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = tf.reshape( + windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] + ) + hidden_states = tf.reshape( + tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1] + ) + + if pad_height > height or pad_width > width: + hidden_states = hidden_states[:, :height, :width, :] + return hidden_states + + def call( + self, + hidden_states: tf.Tensor, + output_attentions: Optional[bool] = False, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + training=training, + ) + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class TFSamVisionNeck(tf.keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + + self.conv1 = tf.keras.layers.Conv2D( + config.output_channels, + kernel_size=1, + use_bias=False, + name="conv1", + ) + self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") + self.conv2 = tf.keras.layers.Conv2D( + config.output_channels, + kernel_size=3, + padding="same", + use_bias=False, + name="conv2", + ) + self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2") + + def call(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) + return hidden_states + + +class TFSamVisionEncoder(tf.keras.layers.Layer): + def __init__(self, config: SamVisionConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.image_size = config.image_size + + self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed") + + self.pos_embed = None + + self.layers = [] + for i in range(config.num_hidden_layers): + layer = TFSamVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + name=f"layers_._{i}", + ) + self.layers.append(layer) + + self.neck = TFSamVisionNeck(config, name="neck") + + def build(self, input_shape): + if self.config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = self.add_weight( + shape=[ + 1, + self.config.image_size // self.config.patch_size, + self.config.image_size // self.config.patch_size, + self.config.hidden_size, + ], + initializer="zeros", + trainable=True, + name="pos_embed", + ) + super().build(input_shape) + + def get_input_embeddings(self): + return self.patch_embed + + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFSamVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return TFSamVisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class TFSamPreTrainedModel(TFPreTrainedModel): + config_class = SamConfig + base_model_prefix = "sam" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=( + 3, + self.config.vision_config.num_channels, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ), + dtype=tf.float32, + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + output = self.call(inputs) + + return self.serving_output(output) + + +SAM_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) + subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`SamConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SAM_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for + details. + input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second + dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per + input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size, + the number of boxes per image and the coordinates of the top left and botton right point of the box. In the + order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `call` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", + " optional 2D location and bounding boxes.", + SAM_START_DOCSTRING, +) +class TFSamModel(TFSamPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") + + self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") + self.prompt_encoder = TFSamPromptEncoder( + config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" + ) + self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") + self.config = config + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + grid = tf.ones((size, size)) + y_embed = tf.math.cumsum(grid, axis=0) - 0.5 + x_embed = tf.math.cumsum(grid, axis=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1)) + return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # channel x height x width + + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + return image_embeddings + + def get_prompt_embeddings( + self, + input_points: Optional[tf.Tensor] = None, + input_labels: Optional[tf.Tensor] = None, + input_boxes: Optional[tf.Tensor] = None, + input_masks: Optional[tf.Tensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @unpack_inputs + @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + input_points: Optional[tf.Tensor] = None, + input_labels: Optional[tf.Tensor] = None, + input_boxes: Optional[tf.Tensor] = None, + input_masks: Optional[tf.Tensor] = None, + image_embeddings: Optional[tf.Tensor] = None, + multimask_output: bool = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict=None, + training=False, + **kwargs, + ) -> List[Dict[str, tf.Tensor]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.", + " got {}.".format(input_points.shape), + ) + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.", + " got {}.".format(input_boxes.shape), + ) + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + if pixel_values is not None: + # Ensures that later checks pass even with an all-None shape from the serving signature + pixel_values = tf.ensure_shape( + pixel_values, + [ + None, + self.config.vision_config.num_channels, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ], + ) + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] + image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + training=training, + ) + image_embeddings = vision_outputs["last_hidden_state"] + + if output_hidden_states: + vision_hidden_states = vision_outputs["hidden_states"] + if output_attentions: + vision_attentions = vision_outputs["attentions"] + + if input_points is not None and input_labels is None: + input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) + + if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: + raise ValueError( + "The batch size of the image embeddings and the input points must be the same. ", + "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), + " if you want to pass multiple points for the same image, make sure that you passed ", + " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", + " input_labels of shape (batch_size, point_batch_size, num_points_per_image)", + ) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + batch_size=shape_list(image_embeddings)[0], + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return TFSamImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) + + def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput: + hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None + + return TFSamImageSegmentationOutput( + iou_scores=output.iou_scores, + pred_masks=output.pred_masks, + vision_hidden_states=hs if self.config.output_hidden_states else None, + vision_attentions=attns if self.config.output_attentions else None, + mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None, + ) diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index 919ed685a8..fd73260f94 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -22,12 +22,15 @@ import numpy as np from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding -from ...utils import TensorType, is_torch_available +from ...utils import TensorType, is_tf_available, is_torch_available if is_torch_available(): import torch +if is_tf_available(): + import tensorflow as tf + class SamProcessor(ProcessorMixin): r""" @@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin): # pop arguments that are not used in the foward but used nevertheless original_sizes = encoding_image_processor["original_sizes"] - if isinstance(original_sizes, torch.Tensor): + if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor original_sizes = original_sizes.numpy() input_points, input_labels, input_boxes = self._check_and_preprocess_points( @@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin): input_boxes = torch.from_numpy(input_boxes) # boxes batch size of 1 by default input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes + elif return_tensors == "tf": + input_boxes = tf.convert_to_tensor(input_boxes) + # boxes batch size of 1 by default + input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes encoding_image_processor.update({"input_boxes": input_boxes}) if input_points is not None: if return_tensors == "pt": input_points = torch.from_numpy(input_points) # point batch size of 1 by default input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points + elif return_tensors == "tf": + input_points = tf.convert_to_tensor(input_points) + # point batch size of 1 by default + input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points encoding_image_processor.update({"input_points": input_points}) if input_labels is not None: if return_tensors == "pt": input_labels = torch.from_numpy(input_labels) # point batch size of 1 by default input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels + elif return_tensors == "tf": + input_labels = tf.convert_to_tensor(input_labels) + # point batch size of 1 by default + input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels encoding_image_processor.update({"input_labels": input_labels}) return encoding_image_processor @@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin): it is converted to a `numpy.ndarray` and then to a `list`. """ if input_points is not None: - if isinstance(input_points, torch.Tensor): + if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor input_points = input_points.numpy().tolist() if not isinstance(input_points, list) or not isinstance(input_points[0], list): @@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin): input_points = None if input_labels is not None: - if isinstance(input_labels, torch.Tensor): + if hasattr(input_labels, "numpy"): input_labels = input_labels.numpy().tolist() if not isinstance(input_labels, list) or not isinstance(input_labels[0], list): @@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin): input_labels = None if input_boxes is not None: - if isinstance(input_boxes, torch.Tensor): + if hasattr(input_boxes, "numpy"): input_boxes = input_boxes.numpy().tolist() if ( diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 2d4fa6fda9..306f73c0b1 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name) +def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1): + # This is a very simplified functional layernorm, designed to duplicate + # the functionality of PyTorch nn.functional.layer_norm when this is needed to port + # models in Transformers. + + if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int): + raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.") + + # Get mean and variance on the axis to be normalized + mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True) + + if axis != -1: + # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions + # on every dimension except axis + shape = [1] * inputs.shape.rank + shape[axis] = shape_list(inputs)[axis] + weight = tf.reshape(weight, shape) + bias = tf.reshape(bias, shape) + + # Compute layer normalization using the batch_normalization + # function. + outputs = tf.nn.batch_normalization( + inputs, + mean, + variance, + offset=bias, + scale=weight, + variance_epsilon=epsilon, + ) + return outputs + + +def flatten(input, start_dim=0, end_dim=-1): + # Replicates the behavior of torch.flatten in TF + + # If end_dim or start_dim is negative, count them from the end + if end_dim < 0: + end_dim += input.shape.rank + if start_dim < 0: + start_dim += input.shape.rank + + if start_dim == end_dim: + return input + + in_shape = tf.shape(input) + flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1]) + out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0) + return tf.reshape(input, out_shape) + + def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor: """ Invert an attention mask (e.g., switches 0. and 1.). diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 2a043f50f3..658d7f689f 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["tf"]) +TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFSamModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSamPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index a45e5c4bf5..8507c0a638 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_hidden_states_output(self): pass + def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4): + super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol) + @slow def test_model_from_pretrained(self): for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4)) + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4)) + self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4)) def test_inference_mask_generation_one_point_one_bb(self): model = SamModel.from_pretrained("facebook/sam-vit-huge") @@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4)) + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4)) + self.assertTrue( + torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4) + ) def test_inference_mask_generation_batched_points_batched_images(self): model = SamModel.from_pretrained("facebook/sam-vit-huge") @@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase): with torch.no_grad(): outputs = model(**inputs) scores = outputs.iou_scores.squeeze().cpu() + masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() EXPECTED_SCORES = torch.tensor( [ @@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase): ], ] ) + EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406]) self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) def test_inference_mask_generation_one_point_one_bb_zero(self): model = SamModel.from_pretrained("facebook/sam-vit-huge") diff --git a/tests/models/sam/test_modeling_tf_sam.py b/tests/models/sam/test_modeling_tf_sam.py new file mode 100644 index 0000000000..a07398365f --- /dev/null +++ b/tests/models/sam/test_modeling_tf_sam.py @@ -0,0 +1,671 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow SAM model. """ + + +import inspect +import unittest + +import numpy as np +import requests + +from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig +from transformers.testing_utils import require_tf, slow +from transformers.utils import is_tf_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers import SamProcessor, TFSamModel + from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST + +if is_vision_available(): + from PIL import Image + + +class TFSamPromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=24, + patch_size=2, + mask_input_channels=4, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return SamPromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + + return config, dummy_points + + +class TFSamMaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=32, + layer_norm_eps=1e-6, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + + def get_config(self): + return SamMaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downsample_rate=self.attention_downsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + layer_norm_eps=self.layer_norm_eps, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + dummy_inputs = { + "image_embedding": floats_tensor([self.batch_size, self.hidden_size]), + } + + return config, dummy_inputs + + +class TFSamModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=2, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + ): + self.parent = parent + self.image_size = image_size + self.patch_size = patch_size + self.output_channels = output_channels + self.num_channels = num_channels + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + self.prompt_encoder_tester = TFSamPromptEncoderTester() + self.mask_decoder_tester = TFSamMaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + vision_config = SamVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + + mask_decoder_config = self.mask_decoder_tester.get_config() + + return SamConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values): + model = TFSamModel(config=config) + result = model(pixel_values) + self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def create_and_check_get_image_features(self, config, pixel_values): + model = TFSamModel(config=config) + result = model.get_image_embeddings(pixel_values) + self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12)) + + def create_and_check_get_image_hidden_states(self, config, pixel_values): + model = TFSamModel(config=config) + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=False, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (TFSamModel,) if is_tf_available() else () + pipeline_model_mapping = ( + {"feature-extraction": TFSamModel, "mask-generation": TFSamModel} if is_tf_available() else {} + ) + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_onnx = False + + # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working + def is_pipeline_test_to_skip( + self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name + ): + return True + + def setUp(self): + self.model_tester = TFSamModelTester(self) + self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False) + self.prompt_encoder_config_tester = ConfigTester( + self, + config_class=SamPromptEncoderConfig, + has_text_modality=False, + num_attention_heads=12, + num_hidden_layers=2, + ) + self.mask_decoder_config_tester = ConfigTester( + self, config_class=SamMaskDecoderConfig, has_text_modality=False + ) + + def test_config(self): + self.vision_config_tester.run_common_tests() + self.prompt_encoder_config_tester.run_common_tests() + self.mask_decoder_config_tester.run_common_tests() + + @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_get_image_features(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_features(*config_and_inputs) + + def test_image_hidden_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_vision_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + self.assertListEqual( + list(vision_attentions[0].shape[-4:]), + list(expected_vision_attention_shape), + ) + + self.assertListEqual( + list(mask_decoder_attentions[0].shape[-4:]), + list(expected_mask_decoder_attention_shape), + ) + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFSamModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None): + super().check_pt_tf_outputs( + tf_outputs=tf_outputs, + pt_outputs=pt_outputs, + model_class=model_class, + tol=tol, + name=name, + attributes=attributes, + ) + + +def prepare_image(): + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +@slow +class SamModelIntegrationTest(unittest.TestCase): + def test_inference_mask_generation_no_point(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + inputs = processor(images=raw_image, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + masks = outputs.pred_masks[0, 0, 0, 0, :3] + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4)) + self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2)) + + def test_inference_mask_generation_one_point_one_bb(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + input_boxes = [[[650, 900, 1000, 1250]]] + input_points = [[[820, 1080]]] + + inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + masks = outputs.pred_masks[0, 0, 0, 0, :3] + + self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4)) + self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2)) + + def test_inference_mask_generation_batched_points_batched_images(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + input_points = [ + [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + ] + + inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + masks = outputs.pred_masks[0, 0, 0, 0, :3] + + EXPECTED_SCORES = np.array( + [ + [ + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + ], + [ + [0.8405, 0.6292, 0.3840], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + [0.9673, 0.9441, 0.9084], + ], + ] + ) + EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406]) + self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2)) + + def test_inference_mask_generation_one_point_one_bb_zero(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + input_boxes = [[[620, 900, 1000, 1255]]] + input_points = [[[820, 1080]]] + labels = [[0]] + + inputs = processor( + images=raw_image, + input_boxes=input_boxes, + input_points=input_points, + input_labels=labels, + return_tensors="tf", + ) + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4)) + + def test_inference_mask_generation_one_point(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4)) + + # With no label + input_points = [[[400, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4)) + + def test_inference_mask_generation_two_points(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) + + # no labels + inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4)) + + def test_inference_mask_generation_two_points_batched(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]], [[400, 650]]] + input_labels = [[1, 1], [1]] + + inputs = processor( + images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="tf" + ) + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4)) + self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4)) + + def test_inference_mask_generation_one_box(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + + input_boxes = [[[75, 275, 1725, 850]]] + + inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf") + + outputs = model(**inputs) + scores = tf.squeeze(outputs.iou_scores) + + self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4)) + + def test_inference_mask_generation_batched_image_one_point(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + raw_dog_image = prepare_dog_img() + + input_points = [[[820, 1080]], [[220, 470]]] + + inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + scores_batched = tf.squeeze(outputs.iou_scores) + + input_points = [[[220, 470]]] + + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + scores_single = tf.squeeze(outputs.iou_scores) + self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4)) + + def test_inference_mask_generation_two_points_point_batch(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + + # fmt: off + input_points = tf.convert_to_tensor([[[400, 650]], [[220, 470]]]) + # fmt: on + + input_points = tf.expand_dims(input_points, 0) + + inputs = processor(raw_image, input_points=input_points, return_tensors="tf") + + outputs = model(**inputs) + + iou_scores = outputs.iou_scores + self.assertTrue(iou_scores.shape == (1, 2, 3)) + self.assertTrue( + np.allclose( + iou_scores.numpy(), + np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]), + atol=1e-4, + rtol=1e-4, + ) + ) + + def test_inference_mask_generation_three_boxes_point_batch(self): + model = TFSamModel.from_pretrained("facebook/sam-vit-huge") + processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") + + raw_image = prepare_image() + + # fmt: off + input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]) + EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]]) + # fmt: on + input_boxes = tf.expand_dims(input_boxes, 0) + + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf") + + outputs = model(**inputs) + + iou_scores = outputs.iou_scores + self.assertTrue(iou_scores.shape == (1, 3, 3)) + self.assertTrue(np.allclose(iou_scores.numpy(), EXPECTED_IOU, atol=1e-4, rtol=1e-4)) diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 13efa22e3e..7d669bb969 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -17,8 +17,14 @@ import unittest import numpy as np -from transformers.testing_utils import require_torch, require_torchvision, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.testing_utils import ( + is_pt_tf_cross_test, + require_tf, + require_torch, + require_torchvision, + require_vision, +) +from transformers.utils import is_tf_available, is_torch_available, is_vision_available if is_vision_available(): @@ -29,6 +35,9 @@ if is_vision_available(): if is_torch_available(): import torch +if is_tf_available(): + import tensorflow as tf + @require_vision @require_torchvision @@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase): dummy_masks = [[1, 0], [0, 1]] with self.assertRaises(ValueError): masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) + + +@require_vision +@require_tf +class TFSamProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SamImageProcessor() + processor = SamProcessor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_additional_features(self): + processor = SamProcessor(image_processor=self.get_image_processor()) + processor.save_pretrained(self.tmpdirname) + + image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) + + processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, SamImageProcessor) + + def test_image_processor(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor + input_feat_extract.pop("reshaped_input_sizes") # pop reshaped_input_sizes as it is popped in the processor + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + @require_tf + def test_post_process_masks(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + dummy_masks = [tf.ones((1, 3, 5, 5))] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf") + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + masks = processor.post_process_masks( + dummy_masks, + tf.convert_to_tensor(original_sizes), + tf.convert_to_tensor(reshaped_input_size), + return_tensors="tf", + ) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + # should also work with np + dummy_masks = [np.ones((1, 3, 5, 5))] + masks = processor.post_process_masks( + dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" + ) + + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + dummy_masks = [[1, 0], [0, 1]] + with self.assertRaises(tf.errors.InvalidArgumentError): + masks = processor.post_process_masks( + dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" + ) + + +@require_vision +@require_torchvision +class SamProcessorEquivalenceTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SamImageProcessor() + processor = SamProcessor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + @is_pt_tf_cross_test + def test_post_process_masks_equivalence(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32) + tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)] + pt_dummy_masks = [torch.tensor(dummy_masks)] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + tf_masks = processor.post_process_masks( + tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf" + ) + pt_masks = processor.post_process_masks( + pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt" + ) + + self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy())) + + @is_pt_tf_cross_test + def test_image_processor_equivalence(self): + image_processor = self.get_image_processor() + + processor = SamProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy() + pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy() + + tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy() + tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy() + + self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor)) + self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract)) + self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))