TF port of the Segment Anything Model (SAM) (#22970)
* First commit * Add auto-translation with GPT-4 * make fixup * Add a functional layernorm for TF * Add all the auxiliary imports etc. * Add the extra processor and tests * rebase to main * Add all the needed fixes to the GPT code * make fixup * Make convolutions channels-last so they run on CPU * make fixup * Fix final issues * Fix other models affected by test change * Clarify comment on the sparse_prompt_embeddings check * Refactor functional_layernorm, use shape_list in place of .shape in some places * Remove deprecated torch-alike code * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Refactor processor with common methods and separated private methods * make fixup * Quietly delete the file that didn't do anything (sorry Sylvain) * Refactor the processor tests into one file * make fixup * Clean up some unnecessary indirection * Fix TF mask postprocessing * Add more processor equivalence tests * Refactor generate_crop_boxes to use framework-neutral np code * Make the serving output correctly conditional * Fix error message line length * Use dict keys rather than indices internally in both TF and PT SAM call/forward * Return dicts internally in the call/forward methods * Revert changes to common tests and just override check_pt_tf_outputs * Revert changes to other model tests * Clarify comments for functional layernorm * Add missing transpose from PT code * Removed unused copied from in PT code * Remove overrides for tests that don't exist in TF * Fix transpose and update tests for PT and TF to check pred_masks * Add training flag * Update tests to use TF checkpoints * Update index.mdx * Add missing cross-test decorator * Remove optional extra asterisks * Revert return_dict changes in PT code * Update src/transformers/models/sam/modeling_tf_sam.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove None return annotations on init methods * Update tests/models/sam/test_processor_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix input_boxes shapes * make fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SAM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SAM | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
||||
@@ -99,3 +99,9 @@ Resources:
|
||||
|
||||
[[autodoc]] SamModel
|
||||
- forward
|
||||
|
||||
|
||||
## TFSamModel
|
||||
|
||||
[[autodoc]] TFSamModel
|
||||
- call
|
||||
@@ -3406,6 +3406,13 @@ else:
|
||||
"TFRoFormerPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.sam"].extend(
|
||||
[
|
||||
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFSamModel",
|
||||
"TFSamPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.segformer"].extend(
|
||||
[
|
||||
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@@ -6657,6 +6664,11 @@ if TYPE_CHECKING:
|
||||
TFRoFormerModel,
|
||||
TFRoFormerPreTrainedModel,
|
||||
)
|
||||
from .models.sam import (
|
||||
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSamModel,
|
||||
TFSamPreTrainedModel,
|
||||
)
|
||||
from .models.segformer import (
|
||||
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSegformerDecodeHead,
|
||||
|
||||
@@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("roberta", "TFRobertaModel"),
|
||||
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
|
||||
("roformer", "TFRoFormerModel"),
|
||||
("sam", "TFSamModel"),
|
||||
("segformer", "TFSegformerModel"),
|
||||
("speech_to_text", "TFSpeech2TextModel"),
|
||||
("swin", "TFSwinModel"),
|
||||
@@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
|
||||
("mobilebert", "TFMobileBertForNextSentencePrediction"),
|
||||
]
|
||||
)
|
||||
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("sam", "TFSamModel"),
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
@@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
|
||||
)
|
||||
|
||||
|
||||
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
|
||||
|
||||
class TFAutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_MAPPING
|
||||
|
||||
@@ -13,7 +13,13 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@@ -39,6 +45,17 @@ else:
|
||||
"SamModel",
|
||||
"SamPreTrainedModel",
|
||||
]
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_sam"] = [
|
||||
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFSamModel",
|
||||
"TFSamPreTrainedModel",
|
||||
]
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -66,6 +83,14 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -34,7 +34,14 @@ from ...image_utils import (
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -44,6 +51,12 @@ if is_torch_available():
|
||||
if is_torchvision_available():
|
||||
from torchvision.ops.boxes import batched_nms
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from tensorflow.experimental import numpy as tnp
|
||||
|
||||
from ...tf_utils import flatten, shape_list
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -372,6 +385,61 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
return encoded_outputs
|
||||
|
||||
def post_process_masks(
|
||||
self,
|
||||
masks,
|
||||
original_sizes,
|
||||
reshaped_input_sizes,
|
||||
mask_threshold=0.0,
|
||||
binarize=True,
|
||||
pad_size=None,
|
||||
return_tensors="pt",
|
||||
):
|
||||
"""
|
||||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Args:
|
||||
masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
|
||||
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
||||
original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
|
||||
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
||||
width) format.
|
||||
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
|
||||
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||||
The threshold to use for binarizing the masks.
|
||||
binarize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to binarize the masks.
|
||||
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
||||
The target size the images were padded to before being passed to the model. If None, the target size is
|
||||
assumed to be the processor's `pad_size`.
|
||||
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
||||
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
|
||||
Returns:
|
||||
(`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where
|
||||
(height, width) is given by original_size.
|
||||
"""
|
||||
if return_tensors == "pt":
|
||||
return self._post_process_masks_pt(
|
||||
masks=masks,
|
||||
original_sizes=original_sizes,
|
||||
reshaped_input_sizes=reshaped_input_sizes,
|
||||
mask_threshold=mask_threshold,
|
||||
binarize=binarize,
|
||||
pad_size=pad_size,
|
||||
)
|
||||
elif return_tensors == "tf":
|
||||
return self._post_process_masks_tf(
|
||||
masks=masks,
|
||||
original_sizes=original_sizes,
|
||||
reshaped_input_sizes=reshaped_input_sizes,
|
||||
mask_threshold=mask_threshold,
|
||||
binarize=binarize,
|
||||
pad_size=pad_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError("return_tensors must be either 'pt' or 'tf'")
|
||||
|
||||
def _post_process_masks_pt(
|
||||
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
||||
):
|
||||
"""
|
||||
@@ -418,21 +486,70 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
|
||||
return output_masks
|
||||
|
||||
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
|
||||
def _post_process_masks_tf(
|
||||
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
||||
):
|
||||
"""
|
||||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Args:
|
||||
masks (`tf.Tensor`):
|
||||
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
||||
original_sizes (`tf.Tensor`):
|
||||
The original size of the images before resizing for input to the model, in (height, width) format.
|
||||
reshaped_input_sizes (`tf.Tensor`):
|
||||
The size of the image input to the model, in (height, width) format. Used to remove padding.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||||
The threshold to use for binarizing the masks.
|
||||
binarize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to binarize the masks.
|
||||
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
||||
The target size the images were padded to before being passed to the model. If None, the target size is
|
||||
assumed to be the processor's `pad_size`.
|
||||
Returns:
|
||||
(`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is
|
||||
given by original_size.
|
||||
"""
|
||||
requires_backends(self, ["tf"])
|
||||
pad_size = self.pad_size if pad_size is None else pad_size
|
||||
target_image_size = (pad_size["height"], pad_size["width"])
|
||||
|
||||
output_masks = []
|
||||
for i, original_size in enumerate(original_sizes):
|
||||
# tf.image expects NHWC, we transpose the NCHW inputs for it
|
||||
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
|
||||
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
|
||||
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
|
||||
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
|
||||
if binarize:
|
||||
interpolated_mask = interpolated_mask > mask_threshold
|
||||
# And then we transpose them back at the end
|
||||
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
|
||||
|
||||
return output_masks
|
||||
|
||||
def post_process_for_mask_generation(
|
||||
self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
|
||||
):
|
||||
"""
|
||||
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
|
||||
|
||||
Args:
|
||||
all_masks (`List[torch.Tensor]`):
|
||||
all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
||||
List of all predicted segmentation masks
|
||||
all_scores (`List[torch.Tensor]`):
|
||||
all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
||||
List of all predicted iou scores
|
||||
all_boxes (`List[torch.Tensor]`):
|
||||
all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
||||
List of all bounding boxes of the predicted masks
|
||||
crops_nms_thresh (`float`):
|
||||
Threshold for NMS (Non Maximum Suppression) algorithm.
|
||||
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||
"""
|
||||
if return_tensors == "pt":
|
||||
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
||||
elif return_tensors == "tf":
|
||||
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
||||
|
||||
def generate_crop_boxes(
|
||||
self,
|
||||
@@ -443,6 +560,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||
device: Optional["torch.device"] = None,
|
||||
return_tensors: str = "pt",
|
||||
):
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
@@ -464,10 +582,35 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
device (`torch.device`, *optional*, defaults to None):
|
||||
Device to use for the computation. If None, cpu will be used.
|
||||
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||
"""
|
||||
return _generate_crop_boxes(
|
||||
image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
|
||||
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
|
||||
image,
|
||||
target_size,
|
||||
crop_n_layers,
|
||||
overlap_ratio,
|
||||
points_per_crop,
|
||||
crop_n_points_downscale_factor,
|
||||
)
|
||||
if return_tensors == "pt":
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
crop_boxes = torch.tensor(crop_boxes, device=device)
|
||||
points_per_crop = torch.tensor(points_per_crop, device=device)
|
||||
# cropped_images stays as np
|
||||
input_labels = torch.tensor(input_labels, device=device)
|
||||
|
||||
elif return_tensors == "tf":
|
||||
if device is not None:
|
||||
raise ValueError("device is not a supported argument when return_tensors is tf!")
|
||||
crop_boxes = tf.convert_to_tensor(crop_boxes)
|
||||
points_per_crop = tf.convert_to_tensor(points_per_crop)
|
||||
# cropped_images stays as np
|
||||
input_labels = tf.convert_to_tensor(input_labels)
|
||||
else:
|
||||
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
|
||||
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||
|
||||
def filter_masks(
|
||||
self,
|
||||
@@ -479,6 +622,67 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
return_tensors="pt",
|
||||
):
|
||||
"""
|
||||
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
||||
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
||||
bounding boxes and pad the predicted masks if necessary.
|
||||
|
||||
Args:
|
||||
masks (`Union[torch.Tensor, tf.Tensor]`):
|
||||
Input masks.
|
||||
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
|
||||
List of IoU scores.
|
||||
original_size (`Tuple[int,int]`):
|
||||
Size of the orginal image.
|
||||
cropped_box_image (`np.array`):
|
||||
The cropped image.
|
||||
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||
The threshold for the iou scores.
|
||||
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||
The threshold for the stability score.
|
||||
mask_threshold (`float`, *optional*, defaults to 0):
|
||||
The threshold for the predicted masks.
|
||||
stability_score_offset (`float`, *optional*, defaults to 1):
|
||||
The offset for the stability score used in the `_compute_stability_score` method.
|
||||
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||
"""
|
||||
if return_tensors == "pt":
|
||||
return self._filter_masks_pt(
|
||||
masks=masks,
|
||||
iou_scores=iou_scores,
|
||||
original_size=original_size,
|
||||
cropped_box_image=cropped_box_image,
|
||||
pred_iou_thresh=pred_iou_thresh,
|
||||
stability_score_thresh=stability_score_thresh,
|
||||
mask_threshold=mask_threshold,
|
||||
stability_score_offset=stability_score_offset,
|
||||
)
|
||||
elif return_tensors == "tf":
|
||||
return self._filter_masks_tf(
|
||||
masks=masks,
|
||||
iou_scores=iou_scores,
|
||||
original_size=original_size,
|
||||
cropped_box_image=cropped_box_image,
|
||||
pred_iou_thresh=pred_iou_thresh,
|
||||
stability_score_thresh=stability_score_thresh,
|
||||
mask_threshold=mask_threshold,
|
||||
stability_score_offset=stability_score_offset,
|
||||
)
|
||||
|
||||
def _filter_masks_pt(
|
||||
self,
|
||||
masks,
|
||||
iou_scores,
|
||||
original_size,
|
||||
cropped_box_image,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
):
|
||||
"""
|
||||
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||
@@ -525,7 +729,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
|
||||
# compute stability score
|
||||
if stability_score_thresh > 0.0:
|
||||
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
|
||||
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
|
||||
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
||||
|
||||
scores = iou_scores[keep_mask]
|
||||
@@ -549,8 +753,85 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
|
||||
return masks, scores, converted_boxes
|
||||
|
||||
def _filter_masks_tf(
|
||||
self,
|
||||
masks,
|
||||
iou_scores,
|
||||
original_size,
|
||||
cropped_box_image,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
):
|
||||
"""
|
||||
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
||||
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
||||
bounding boxes and pad the predicted masks if necessary.
|
||||
|
||||
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||
Args:
|
||||
masks (`tf.Tensor`):
|
||||
Input masks.
|
||||
iou_scores (`tf.Tensor`):
|
||||
List of IoU scores.
|
||||
original_size (`Tuple[int,int]`):
|
||||
Size of the orginal image.
|
||||
cropped_box_image (`np.array`):
|
||||
The cropped image.
|
||||
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||
The threshold for the iou scores.
|
||||
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||
The threshold for the stability score.
|
||||
mask_threshold (`float`, *optional*, defaults to 0):
|
||||
The threshold for the predicted masks.
|
||||
stability_score_offset (`float`, *optional*, defaults to 1):
|
||||
The offset for the stability score used in the `_compute_stability_score` method.
|
||||
|
||||
"""
|
||||
requires_backends(self, ["tf"])
|
||||
original_height, original_width = original_size
|
||||
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
|
||||
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
|
||||
|
||||
if masks.shape[0] != iou_scores.shape[0]:
|
||||
raise ValueError("masks and iou_scores must have the same batch size.")
|
||||
|
||||
batch_size = masks.shape[0]
|
||||
|
||||
keep_mask = tf.ones(batch_size, dtype=tf.bool)
|
||||
|
||||
if pred_iou_thresh > 0.0:
|
||||
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
||||
|
||||
# compute stability score
|
||||
if stability_score_thresh > 0.0:
|
||||
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
|
||||
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
||||
|
||||
scores = iou_scores[keep_mask]
|
||||
masks = masks[keep_mask]
|
||||
|
||||
# binarize masks
|
||||
masks = masks > mask_threshold
|
||||
converted_boxes = _batched_mask_to_box_tf(masks)
|
||||
|
||||
keep_mask = ~_is_box_near_crop_edge_tf(
|
||||
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
||||
)
|
||||
|
||||
scores = scores[keep_mask]
|
||||
masks = masks[keep_mask]
|
||||
converted_boxes = converted_boxes[keep_mask]
|
||||
|
||||
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
|
||||
# conversion to rle is necessary to run non-maximum suppresion
|
||||
masks = _mask_to_rle_tf(masks)
|
||||
|
||||
return masks, scores, converted_boxes
|
||||
|
||||
|
||||
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecesary cast to torch.int64
|
||||
intersections = (
|
||||
@@ -561,6 +842,17 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi
|
||||
return stability_scores
|
||||
|
||||
|
||||
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||
# Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
|
||||
# we get the right division results.
|
||||
intersections = tf.count_nonzero(
|
||||
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
|
||||
)
|
||||
unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
|
||||
stability_scores = intersections / unions
|
||||
return stability_scores
|
||||
|
||||
|
||||
def _build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
@@ -606,7 +898,6 @@ def _generate_crop_boxes(
|
||||
overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
@@ -626,11 +917,7 @@ def _generate_crop_boxes(
|
||||
Number of points to sample per crop.
|
||||
crop_n_points_downscale_factor (`int`, *optional*):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
device (`torch.device`, *optional*):
|
||||
Device to run the crop generation on. Defaults to CPU.
|
||||
"""
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
|
||||
if isinstance(image, list):
|
||||
raise ValueError("Only one image is allowed for crop generation.")
|
||||
@@ -648,12 +935,11 @@ def _generate_crop_boxes(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||
)
|
||||
|
||||
crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
|
||||
point_grid_per_crop = np.array([point_grid_per_crop])
|
||||
points_per_crop = torch.tensor(point_grid_per_crop, device=device)
|
||||
points_per_crop = points_per_crop.permute(0, 2, 1, 3)
|
||||
crop_boxes = crop_boxes.astype(np.float32)
|
||||
points_per_crop = np.array([point_grid_per_crop])
|
||||
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
||||
|
||||
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)
|
||||
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
|
||||
|
||||
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||
|
||||
@@ -730,6 +1016,16 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
||||
left, top, right, bottom = crop_box
|
||||
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
||||
pad = (left, pad_x - left, top, pad_y - top)
|
||||
return tf.pad(masks, pad, constant_values=0)
|
||||
|
||||
|
||||
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
@@ -748,6 +1044,24 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
|
||||
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
|
||||
|
||||
left, top, _, _ = crop_box
|
||||
offset = tf.convert_to_tensor([[left, top, left, top]])
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = tf.expand_dims(offset, 1)
|
||||
boxes = tf.cast(boxes + offset, tf.float32)
|
||||
|
||||
near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return tf.reduce_any(near_crop_edge, axis=1)
|
||||
|
||||
|
||||
def _batched_mask_to_box(masks: "torch.Tensor"):
|
||||
"""
|
||||
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
||||
@@ -797,6 +1111,54 @@ def _batched_mask_to_box(masks: "torch.Tensor"):
|
||||
return out
|
||||
|
||||
|
||||
def _batched_mask_to_box_tf(masks: "tf.Tensor"):
|
||||
"""
|
||||
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
||||
corresponds the following required indices:
|
||||
- LEFT: left hand side of the bounding box
|
||||
- TOP: top of the bounding box
|
||||
- RIGHT: right of the bounding box
|
||||
- BOTTOM: bottom of the bounding box
|
||||
|
||||
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
||||
is channel_1 x channel_2 x ... x 4.
|
||||
|
||||
Args:
|
||||
- masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
|
||||
"""
|
||||
|
||||
if tf.size(masks) == 0:
|
||||
return tf.zeros([*masks.shape[:-2], 4])
|
||||
|
||||
# Normalize shape to Cxheightxwidth
|
||||
shape = shape_list(masks)
|
||||
height, width = shape[-2:]
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height = tf.reduce_max(masks, axis=-1)
|
||||
in_height_coords = in_height * tf.range(height)[None, :]
|
||||
bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
|
||||
in_height_coords = in_height_coords + height * (~in_height)
|
||||
top_edges = tf.reduce_min(in_height_coords, axis=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = tf.reduce_max(masks, axis=-2)
|
||||
in_width_coords = in_width * tf.range(width)[None, :]
|
||||
right_edges, _ = tf.reduce_max(in_width_coords, axis=-1)
|
||||
in_width_coords = in_width_coords + width * (~in_width)
|
||||
left_edges, _ = tf.reduce_min(in_width_coords, axis=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
|
||||
out = out * tf.expand_dims(~empty_filter, -1)
|
||||
|
||||
# Return to original shape
|
||||
out = tf.reshape(out, *shape[:-2], 4)
|
||||
return out
|
||||
|
||||
|
||||
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
||||
"""
|
||||
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
||||
@@ -820,6 +1182,29 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
||||
return out
|
||||
|
||||
|
||||
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
|
||||
"""
|
||||
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
||||
"""
|
||||
# Put in fortran order and flatten height and width
|
||||
batch_size, height, width = input_mask.shape
|
||||
input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
|
||||
|
||||
# Compute change indices
|
||||
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
||||
change_indices = tf.where(diff)
|
||||
|
||||
# Encode run length
|
||||
out = []
|
||||
for i in range(batch_size):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if input_mask[i, 0] == 0 else [0]
|
||||
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
|
||||
out.append({"size": [height, width], "counts": counts})
|
||||
return out
|
||||
|
||||
|
||||
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||
"""Compute a binary mask from an uncompressed RLE."""
|
||||
height, width = rle["size"]
|
||||
@@ -836,7 +1221,7 @@ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||
|
||||
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
||||
"""
|
||||
Perform NMS (Non Maxium Suppression) on the outputs.
|
||||
Perform NMS (Non Maximum Suppression) on the outputs.
|
||||
|
||||
Args:
|
||||
rle_masks (`torch.Tensor`):
|
||||
@@ -861,3 +1246,32 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=
|
||||
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
||||
|
||||
return masks, iou_scores, rle_masks, mask_boxes
|
||||
|
||||
|
||||
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
||||
"""
|
||||
Perform NMS (Non Maximum Suppression) on the outputs.
|
||||
|
||||
Args:
|
||||
rle_masks (`tf.Tensor`):
|
||||
binary masks in the RLE format
|
||||
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
|
||||
iou_scores predicted by the model
|
||||
mask_boxes (`tf.Tensor`):
|
||||
The bounding boxes corresponding to segmentation masks
|
||||
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
||||
NMS threshold.
|
||||
"""
|
||||
keep_by_nms = tf.image.combined_non_max_suppression(
|
||||
boxes=mask_boxes.float(),
|
||||
scores=iou_scores,
|
||||
idxs=torch.zeros(mask_boxes.shape[0]),
|
||||
iou_threshold=amg_crops_nms_thresh,
|
||||
)
|
||||
|
||||
iou_scores = iou_scores[keep_by_nms]
|
||||
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
||||
mask_boxes = mask_boxes[keep_by_nms]
|
||||
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
||||
|
||||
return masks, iou_scores, rle_masks, mask_boxes
|
||||
|
||||
@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
|
||||
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
|
||||
class SamPatchEmbeddings(nn.Module):
|
||||
"""
|
||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||
@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
|
||||
values.
|
||||
"""
|
||||
|
||||
def __init__(self, config, downsample_rate=None) -> None:
|
||||
def __init__(self, config, downsample_rate=None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
|
||||
|
||||
|
||||
class SamTwoWayAttentionBlock(nn.Module):
|
||||
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
|
||||
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
|
||||
"""
|
||||
A transformer block with four layers:
|
||||
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
|
||||
@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
|
||||
the embeddings of the mask inputs
|
||||
multimask_output (bool):
|
||||
Whether to return multiple masks or a single mask.
|
||||
output_attentions (bool, **optional**):
|
||||
output_attentions (bool, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
"""
|
||||
batch_size, num_channels, height, width = image_embeddings.shape
|
||||
@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
|
||||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
points (`torch.Tensor`, **optionnal**):
|
||||
points (`torch.Tensor`, *optional*):
|
||||
point coordinates and labels to embed.
|
||||
boxes (`torch.Tensor`, **optionnal**):
|
||||
boxes (`torch.Tensor`, *optional*):
|
||||
boxes to embed
|
||||
masks (`torch.Tensor`, **optionnal**):
|
||||
masks (`torch.Tensor`, *optional*):
|
||||
masks to embed
|
||||
"""
|
||||
sparse_embeddings = None
|
||||
@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
|
||||
class SamVisionAttention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(self, config, window_size) -> None:
|
||||
def __init__(self, config, window_size):
|
||||
super().__init__()
|
||||
input_size = (
|
||||
(config.image_size // config.patch_size, config.image_size // config.patch_size)
|
||||
@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
|
||||
|
||||
|
||||
class SamVisionLayer(nn.Module):
|
||||
def __init__(self, config, window_size) -> None:
|
||||
def __init__(self, config, window_size):
|
||||
super().__init__()
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attn = SamVisionAttention(config, window_size)
|
||||
@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
|
||||
class SamModel(SamPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
|
||||
|
||||
@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
|
||||
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
|
||||
|
||||
vision_attentions = None
|
||||
mask_decoder_attentions = None
|
||||
vision_hidden_states = None
|
||||
|
||||
if pixel_values is not None:
|
||||
@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
|
||||
"The batch size of the image embeddings and the input points must be the same. ",
|
||||
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
|
||||
" if you want to pass multiple points for the same image, make sure that you passed ",
|
||||
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
||||
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
|
||||
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
||||
)
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
||||
|
||||
1497
src/transformers/models/sam/modeling_tf_sam.py
Normal file
1497
src/transformers/models/sam/modeling_tf_sam.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,12 +22,15 @@ import numpy as np
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import TensorType, is_torch_available
|
||||
from ...utils import TensorType, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class SamProcessor(ProcessorMixin):
|
||||
r"""
|
||||
@@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin):
|
||||
# pop arguments that are not used in the foward but used nevertheless
|
||||
original_sizes = encoding_image_processor["original_sizes"]
|
||||
|
||||
if isinstance(original_sizes, torch.Tensor):
|
||||
if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
|
||||
original_sizes = original_sizes.numpy()
|
||||
|
||||
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
|
||||
@@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin):
|
||||
input_boxes = torch.from_numpy(input_boxes)
|
||||
# boxes batch size of 1 by default
|
||||
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
|
||||
elif return_tensors == "tf":
|
||||
input_boxes = tf.convert_to_tensor(input_boxes)
|
||||
# boxes batch size of 1 by default
|
||||
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
|
||||
encoding_image_processor.update({"input_boxes": input_boxes})
|
||||
if input_points is not None:
|
||||
if return_tensors == "pt":
|
||||
input_points = torch.from_numpy(input_points)
|
||||
# point batch size of 1 by default
|
||||
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
|
||||
elif return_tensors == "tf":
|
||||
input_points = tf.convert_to_tensor(input_points)
|
||||
# point batch size of 1 by default
|
||||
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
|
||||
encoding_image_processor.update({"input_points": input_points})
|
||||
if input_labels is not None:
|
||||
if return_tensors == "pt":
|
||||
input_labels = torch.from_numpy(input_labels)
|
||||
# point batch size of 1 by default
|
||||
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
|
||||
elif return_tensors == "tf":
|
||||
input_labels = tf.convert_to_tensor(input_labels)
|
||||
# point batch size of 1 by default
|
||||
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
|
||||
encoding_image_processor.update({"input_labels": input_labels})
|
||||
|
||||
return encoding_image_processor
|
||||
@@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin):
|
||||
it is converted to a `numpy.ndarray` and then to a `list`.
|
||||
"""
|
||||
if input_points is not None:
|
||||
if isinstance(input_points, torch.Tensor):
|
||||
if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
|
||||
input_points = input_points.numpy().tolist()
|
||||
|
||||
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
|
||||
@@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin):
|
||||
input_points = None
|
||||
|
||||
if input_labels is not None:
|
||||
if isinstance(input_labels, torch.Tensor):
|
||||
if hasattr(input_labels, "numpy"):
|
||||
input_labels = input_labels.numpy().tolist()
|
||||
|
||||
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
|
||||
@@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin):
|
||||
input_labels = None
|
||||
|
||||
if input_boxes is not None:
|
||||
if isinstance(input_boxes, torch.Tensor):
|
||||
if hasattr(input_boxes, "numpy"):
|
||||
input_boxes = input_boxes.numpy().tolist()
|
||||
|
||||
if (
|
||||
|
||||
@@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
|
||||
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
|
||||
|
||||
|
||||
def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
|
||||
# This is a very simplified functional layernorm, designed to duplicate
|
||||
# the functionality of PyTorch nn.functional.layer_norm when this is needed to port
|
||||
# models in Transformers.
|
||||
|
||||
if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
|
||||
raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
|
||||
|
||||
# Get mean and variance on the axis to be normalized
|
||||
mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
|
||||
|
||||
if axis != -1:
|
||||
# Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
|
||||
# on every dimension except axis
|
||||
shape = [1] * inputs.shape.rank
|
||||
shape[axis] = shape_list(inputs)[axis]
|
||||
weight = tf.reshape(weight, shape)
|
||||
bias = tf.reshape(bias, shape)
|
||||
|
||||
# Compute layer normalization using the batch_normalization
|
||||
# function.
|
||||
outputs = tf.nn.batch_normalization(
|
||||
inputs,
|
||||
mean,
|
||||
variance,
|
||||
offset=bias,
|
||||
scale=weight,
|
||||
variance_epsilon=epsilon,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def flatten(input, start_dim=0, end_dim=-1):
|
||||
# Replicates the behavior of torch.flatten in TF
|
||||
|
||||
# If end_dim or start_dim is negative, count them from the end
|
||||
if end_dim < 0:
|
||||
end_dim += input.shape.rank
|
||||
if start_dim < 0:
|
||||
start_dim += input.shape.rank
|
||||
|
||||
if start_dim == end_dim:
|
||||
return input
|
||||
|
||||
in_shape = tf.shape(input)
|
||||
flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
|
||||
out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
|
||||
return tf.reshape(input, out_shape)
|
||||
|
||||
|
||||
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Invert an attention mask (e.g., switches 0. and 1.).
|
||||
|
||||
@@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFSamModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSamPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
|
||||
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
|
||||
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
@@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
|
||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
|
||||
self.assertTrue(
|
||||
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
|
||||
)
|
||||
|
||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
@@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
scores = outputs.iou_scores.squeeze().cpu()
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
|
||||
|
||||
EXPECTED_SCORES = torch.tensor(
|
||||
[
|
||||
@@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase):
|
||||
],
|
||||
]
|
||||
)
|
||||
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
|
||||
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
671
tests/models/sam/test_modeling_tf_sam.py
Normal file
671
tests/models/sam/test_modeling_tf_sam.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow SAM model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
from transformers.utils import is_tf_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import SamProcessor, TFSamModel
|
||||
from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class TFSamPromptEncoderTester:
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=32,
|
||||
input_image_size=24,
|
||||
patch_size=2,
|
||||
mask_input_channels=4,
|
||||
num_point_embeddings=4,
|
||||
hidden_act="gelu",
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.input_image_size = input_image_size
|
||||
self.patch_size = patch_size
|
||||
self.mask_input_channels = mask_input_channels
|
||||
self.num_point_embeddings = num_point_embeddings
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
def get_config(self):
|
||||
return SamPromptEncoderConfig(
|
||||
image_size=self.input_image_size,
|
||||
patch_size=self.patch_size,
|
||||
mask_input_channels=self.mask_input_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_point_embeddings=self.num_point_embeddings,
|
||||
hidden_act=self.hidden_act,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
dummy_points = floats_tensor([self.batch_size, 3, 2])
|
||||
config = self.get_config()
|
||||
|
||||
return config, dummy_points
|
||||
|
||||
|
||||
class TFSamMaskDecoderTester:
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=32,
|
||||
hidden_act="relu",
|
||||
mlp_dim=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
attention_downsample_rate=2,
|
||||
num_multimask_outputs=3,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=32,
|
||||
layer_norm_eps=1e-6,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.mlp_dim = mlp_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_downsample_rate = attention_downsample_rate
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
self.iou_head_depth = iou_head_depth
|
||||
self.iou_head_hidden_dim = iou_head_hidden_dim
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
def get_config(self):
|
||||
return SamMaskDecoderConfig(
|
||||
hidden_size=self.hidden_size,
|
||||
hidden_act=self.hidden_act,
|
||||
mlp_dim=self.mlp_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
attention_downsample_rate=self.attention_downsample_rate,
|
||||
num_multimask_outputs=self.num_multimask_outputs,
|
||||
iou_head_depth=self.iou_head_depth,
|
||||
iou_head_hidden_dim=self.iou_head_hidden_dim,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
config = self.get_config()
|
||||
|
||||
dummy_inputs = {
|
||||
"image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
|
||||
}
|
||||
|
||||
return config, dummy_inputs
|
||||
|
||||
|
||||
class TFSamModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
hidden_size=36,
|
||||
intermediate_size=72,
|
||||
projection_dim=62,
|
||||
output_channels=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_channels=3,
|
||||
image_size=24,
|
||||
patch_size=2,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-06,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
qkv_bias=True,
|
||||
mlp_ratio=4.0,
|
||||
use_abs_pos=True,
|
||||
use_rel_pos=True,
|
||||
rel_pos_zero_init=False,
|
||||
window_size=14,
|
||||
global_attn_indexes=[2, 5, 8, 11],
|
||||
num_pos_feats=16,
|
||||
mlp_dim=None,
|
||||
batch_size=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.output_channels = output_channels
|
||||
self.num_channels = num_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.initializer_factor = initializer_factor
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.qkv_bias = qkv_bias
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.use_abs_pos = use_abs_pos
|
||||
self.use_rel_pos = use_rel_pos
|
||||
self.rel_pos_zero_init = rel_pos_zero_init
|
||||
self.window_size = window_size
|
||||
self.global_attn_indexes = global_attn_indexes
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.mlp_dim = mlp_dim
|
||||
self.batch_size = batch_size
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
|
||||
self.prompt_encoder_tester = TFSamPromptEncoderTester()
|
||||
self.mask_decoder_tester = TFSamMaskDecoderTester()
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
vision_config = SamVisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
initializer_range=self.initializer_range,
|
||||
initializer_factor=self.initializer_factor,
|
||||
output_channels=self.output_channels,
|
||||
qkv_bias=self.qkv_bias,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
use_abs_pos=self.use_abs_pos,
|
||||
use_rel_pos=self.use_rel_pos,
|
||||
rel_pos_zero_init=self.rel_pos_zero_init,
|
||||
window_size=self.window_size,
|
||||
global_attn_indexes=self.global_attn_indexes,
|
||||
num_pos_feats=self.num_pos_feats,
|
||||
mlp_dim=self.mlp_dim,
|
||||
)
|
||||
|
||||
prompt_encoder_config = self.prompt_encoder_tester.get_config()
|
||||
|
||||
mask_decoder_config = self.mask_decoder_tester.get_config()
|
||||
|
||||
return SamConfig(
|
||||
vision_config=vision_config,
|
||||
prompt_encoder_config=prompt_encoder_config,
|
||||
mask_decoder_config=mask_decoder_config,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = TFSamModel(config=config)
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
|
||||
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
|
||||
|
||||
def create_and_check_get_image_features(self, config, pixel_values):
|
||||
model = TFSamModel(config=config)
|
||||
result = model.get_image_embeddings(pixel_values)
|
||||
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
|
||||
|
||||
def create_and_check_get_image_hidden_states(self, config, pixel_values):
|
||||
model = TFSamModel(config=config)
|
||||
result = model.vision_encoder(
|
||||
pixel_values,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# after computing the convolutional features
|
||||
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
|
||||
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
|
||||
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
|
||||
|
||||
result = model.vision_encoder(
|
||||
pixel_values,
|
||||
output_hidden_states=True,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# after computing the convolutional features
|
||||
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
|
||||
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
|
||||
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (TFSamModel,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": TFSamModel, "mask-generation": TFSamModel} if is_tf_available() else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
):
|
||||
return True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFSamModelTester(self)
|
||||
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
|
||||
self.prompt_encoder_config_tester = ConfigTester(
|
||||
self,
|
||||
config_class=SamPromptEncoderConfig,
|
||||
has_text_modality=False,
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=2,
|
||||
)
|
||||
self.mask_decoder_config_tester = ConfigTester(
|
||||
self, config_class=SamMaskDecoderConfig, has_text_modality=False
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.vision_config_tester.run_common_tests()
|
||||
self.prompt_encoder_config_tester.run_common_tests()
|
||||
self.mask_decoder_config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_get_image_features(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
|
||||
|
||||
def test_image_hidden_states(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
expected_vision_attention_shape = (
|
||||
self.model_tester.batch_size * self.model_tester.num_attention_heads,
|
||||
196,
|
||||
196,
|
||||
)
|
||||
expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
vision_attentions = outputs.vision_attentions
|
||||
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
mask_decoder_attentions = outputs.mask_decoder_attentions
|
||||
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
vision_attentions = outputs.vision_attentions
|
||||
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
mask_decoder_attentions = outputs.mask_decoder_attentions
|
||||
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(vision_attentions[0].shape[-4:]),
|
||||
list(expected_vision_attention_shape),
|
||||
)
|
||||
|
||||
self.assertListEqual(
|
||||
list(mask_decoder_attentions[0].shape[-4:]),
|
||||
list(expected_mask_decoder_attention_shape),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFSamModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
|
||||
super().check_pt_tf_outputs(
|
||||
tf_outputs=tf_outputs,
|
||||
pt_outputs=pt_outputs,
|
||||
model_class=model_class,
|
||||
tol=tol,
|
||||
name=name,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
|
||||
def prepare_image():
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
return raw_image
|
||||
|
||||
|
||||
def prepare_dog_img():
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
return raw_image
|
||||
|
||||
|
||||
@slow
|
||||
class SamModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_mask_generation_no_point(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
inputs = processor(images=raw_image, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4))
|
||||
self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
input_boxes = [[[650, 900, 1000, 1250]]]
|
||||
input_points = [[[820, 1080]]]
|
||||
|
||||
inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4))
|
||||
self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2))
|
||||
|
||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
input_points = [
|
||||
[[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
|
||||
[[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
|
||||
]
|
||||
|
||||
inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||
|
||||
EXPECTED_SCORES = np.array(
|
||||
[
|
||||
[
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
],
|
||||
[
|
||||
[0.8405, 0.6292, 0.3840],
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
[0.9673, 0.9441, 0.9084],
|
||||
],
|
||||
]
|
||||
)
|
||||
EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406])
|
||||
self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3))
|
||||
self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2))
|
||||
|
||||
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
input_boxes = [[[620, 900, 1000, 1255]]]
|
||||
input_points = [[[820, 1080]]]
|
||||
labels = [[0]]
|
||||
|
||||
inputs = processor(
|
||||
images=raw_image,
|
||||
input_boxes=input_boxes,
|
||||
input_points=input_points,
|
||||
input_labels=labels,
|
||||
return_tensors="tf",
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_one_point(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
|
||||
input_points = [[[400, 650]]]
|
||||
input_labels = [[1]]
|
||||
|
||||
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4))
|
||||
|
||||
# With no label
|
||||
input_points = [[[400, 650]]]
|
||||
|
||||
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_two_points(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
raw_image = prepare_image()
|
||||
|
||||
input_points = [[[400, 650], [800, 650]]]
|
||||
input_labels = [[1, 1]]
|
||||
|
||||
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
|
||||
|
||||
# no labels
|
||||
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_two_points_batched(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
|
||||
input_points = [[[400, 650], [800, 650]], [[400, 650]]]
|
||||
input_labels = [[1, 1], [1]]
|
||||
|
||||
inputs = processor(
|
||||
images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="tf"
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4))
|
||||
self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_one_box(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
|
||||
input_boxes = [[[75, 275, 1725, 850]]]
|
||||
|
||||
inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_batched_image_one_point(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
raw_dog_image = prepare_dog_img()
|
||||
|
||||
input_points = [[[820, 1080]], [[220, 470]]]
|
||||
|
||||
inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores_batched = tf.squeeze(outputs.iou_scores)
|
||||
|
||||
input_points = [[[220, 470]]]
|
||||
|
||||
inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
scores_single = tf.squeeze(outputs.iou_scores)
|
||||
self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4))
|
||||
|
||||
def test_inference_mask_generation_two_points_point_batch(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
|
||||
# fmt: off
|
||||
input_points = tf.convert_to_tensor([[[400, 650]], [[220, 470]]])
|
||||
# fmt: on
|
||||
|
||||
input_points = tf.expand_dims(input_points, 0)
|
||||
|
||||
inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
iou_scores = outputs.iou_scores
|
||||
self.assertTrue(iou_scores.shape == (1, 2, 3))
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
iou_scores.numpy(),
|
||||
np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
|
||||
def test_inference_mask_generation_three_boxes_point_batch(self):
|
||||
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
raw_image = prepare_image()
|
||||
|
||||
# fmt: off
|
||||
input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]])
|
||||
EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]])
|
||||
# fmt: on
|
||||
input_boxes = tf.expand_dims(input_boxes, 0)
|
||||
|
||||
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf")
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
iou_scores = outputs.iou_scores
|
||||
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
||||
self.assertTrue(np.allclose(iou_scores.numpy(), EXPECTED_IOU, atol=1e-4, rtol=1e-4))
|
||||
@@ -17,8 +17,14 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_torchvision, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import (
|
||||
is_pt_tf_cross_test,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torchvision,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -29,6 +35,9 @@ if is_vision_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
@@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase):
|
||||
dummy_masks = [[1, 0], [0, 1]]
|
||||
with self.assertRaises(ValueError):
|
||||
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_tf
|
||||
class TFSamProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = SamImageProcessor()
|
||||
processor = SamProcessor(image_processor)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
"""
|
||||
|
||||
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = SamProcessor(image_processor=self.get_image_processor())
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
|
||||
|
||||
processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)
|
||||
|
||||
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.image_processor, SamImageProcessor)
|
||||
|
||||
def test_image_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
input_feat_extract = image_processor(image_input, return_tensors="np")
|
||||
input_processor = processor(images=image_input, return_tensors="np")
|
||||
|
||||
input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor
|
||||
input_feat_extract.pop("reshaped_input_sizes") # pop reshaped_input_sizes as it is popped in the processor
|
||||
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
@require_tf
|
||||
def test_post_process_masks(self):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
dummy_masks = [tf.ones((1, 3, 5, 5))]
|
||||
|
||||
original_sizes = [[1764, 2646]]
|
||||
|
||||
reshaped_input_size = [[683, 1024]]
|
||||
masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf")
|
||||
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
|
||||
|
||||
masks = processor.post_process_masks(
|
||||
dummy_masks,
|
||||
tf.convert_to_tensor(original_sizes),
|
||||
tf.convert_to_tensor(reshaped_input_size),
|
||||
return_tensors="tf",
|
||||
)
|
||||
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
|
||||
|
||||
# should also work with np
|
||||
dummy_masks = [np.ones((1, 3, 5, 5))]
|
||||
masks = processor.post_process_masks(
|
||||
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
|
||||
)
|
||||
|
||||
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
|
||||
|
||||
dummy_masks = [[1, 0], [0, 1]]
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
masks = processor.post_process_masks(
|
||||
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
|
||||
)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
class SamProcessorEquivalenceTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = SamImageProcessor()
|
||||
processor = SamProcessor(image_processor)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
"""
|
||||
|
||||
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_post_process_masks_equivalence(self):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
|
||||
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
|
||||
pt_dummy_masks = [torch.tensor(dummy_masks)]
|
||||
|
||||
original_sizes = [[1764, 2646]]
|
||||
|
||||
reshaped_input_size = [[683, 1024]]
|
||||
tf_masks = processor.post_process_masks(
|
||||
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
|
||||
)
|
||||
pt_masks = processor.post_process_masks(
|
||||
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
|
||||
)
|
||||
|
||||
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_image_processor_equivalence(self):
|
||||
image_processor = self.get_image_processor()
|
||||
|
||||
processor = SamProcessor(image_processor=image_processor)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
|
||||
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
|
||||
|
||||
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
|
||||
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
|
||||
|
||||
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
|
||||
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
|
||||
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))
|
||||
|
||||
Reference in New Issue
Block a user