TF port of the Segment Anything Model (SAM) (#22970)
* First commit * Add auto-translation with GPT-4 * make fixup * Add a functional layernorm for TF * Add all the auxiliary imports etc. * Add the extra processor and tests * rebase to main * Add all the needed fixes to the GPT code * make fixup * Make convolutions channels-last so they run on CPU * make fixup * Fix final issues * Fix other models affected by test change * Clarify comment on the sparse_prompt_embeddings check * Refactor functional_layernorm, use shape_list in place of .shape in some places * Remove deprecated torch-alike code * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Refactor processor with common methods and separated private methods * make fixup * Quietly delete the file that didn't do anything (sorry Sylvain) * Refactor the processor tests into one file * make fixup * Clean up some unnecessary indirection * Fix TF mask postprocessing * Add more processor equivalence tests * Refactor generate_crop_boxes to use framework-neutral np code * Make the serving output correctly conditional * Fix error message line length * Use dict keys rather than indices internally in both TF and PT SAM call/forward * Return dicts internally in the call/forward methods * Revert changes to common tests and just override check_pt_tf_outputs * Revert changes to other model tests * Clarify comments for functional layernorm * Add missing transpose from PT code * Removed unused copied from in PT code * Remove overrides for tests that don't exist in TF * Fix transpose and update tests for PT and TF to check pred_masks * Add training flag * Update tests to use TF checkpoints * Update index.mdx * Add missing cross-test decorator * Remove optional extra asterisks * Revert return_dict changes in PT code * Update src/transformers/models/sam/modeling_tf_sam.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove None return annotations on init methods * Update tests/models/sam/test_processor_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix input_boxes shapes * make fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| SAM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| SAM | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
|||||||
@@ -99,3 +99,9 @@ Resources:
|
|||||||
|
|
||||||
[[autodoc]] SamModel
|
[[autodoc]] SamModel
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
## TFSamModel
|
||||||
|
|
||||||
|
[[autodoc]] TFSamModel
|
||||||
|
- call
|
||||||
@@ -3406,6 +3406,13 @@ else:
|
|||||||
"TFRoFormerPreTrainedModel",
|
"TFRoFormerPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.sam"].extend(
|
||||||
|
[
|
||||||
|
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"TFSamModel",
|
||||||
|
"TFSamPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.segformer"].extend(
|
_import_structure["models.segformer"].extend(
|
||||||
[
|
[
|
||||||
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -6657,6 +6664,11 @@ if TYPE_CHECKING:
|
|||||||
TFRoFormerModel,
|
TFRoFormerModel,
|
||||||
TFRoFormerPreTrainedModel,
|
TFRoFormerPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.sam import (
|
||||||
|
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
TFSamModel,
|
||||||
|
TFSamPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.segformer import (
|
from .models.segformer import (
|
||||||
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFSegformerDecodeHead,
|
TFSegformerDecodeHead,
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("roberta", "TFRobertaModel"),
|
("roberta", "TFRobertaModel"),
|
||||||
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
|
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
|
||||||
("roformer", "TFRoFormerModel"),
|
("roformer", "TFRoFormerModel"),
|
||||||
|
("sam", "TFSamModel"),
|
||||||
("segformer", "TFSegformerModel"),
|
("segformer", "TFSegformerModel"),
|
||||||
("speech_to_text", "TFSpeech2TextModel"),
|
("speech_to_text", "TFSpeech2TextModel"),
|
||||||
("swin", "TFSwinModel"),
|
("swin", "TFSwinModel"),
|
||||||
@@ -426,6 +427,11 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
|
|||||||
("mobilebert", "TFMobileBertForNextSentencePrediction"),
|
("mobilebert", "TFMobileBertForNextSentencePrediction"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
("sam", "TFSamModel"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
|
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
|
||||||
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||||
@@ -476,6 +482,14 @@ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
|||||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||||
|
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModel(_BaseAutoModelClass):
|
class TFAutoModel(_BaseAutoModelClass):
|
||||||
_model_mapping = TF_MODEL_MAPPING
|
_model_mapping = TF_MODEL_MAPPING
|
||||||
|
|||||||
@@ -13,7 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
from ...utils import (
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -39,6 +45,17 @@ else:
|
|||||||
"SamModel",
|
"SamModel",
|
||||||
"SamPreTrainedModel",
|
"SamPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
try:
|
||||||
|
if not is_tf_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_import_structure["modeling_tf_sam"] = [
|
||||||
|
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"TFSamModel",
|
||||||
|
"TFSamPreTrainedModel",
|
||||||
|
]
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
@@ -66,6 +83,14 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
|
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not is_tf_available():
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
raise OptionalDependencyNotAvailable()
|
raise OptionalDependencyNotAvailable()
|
||||||
|
|||||||
@@ -34,7 +34,14 @@ from ...image_utils import (
|
|||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
)
|
)
|
||||||
from ...utils import TensorType, is_torch_available, is_torchvision_available, logging, requires_backends
|
from ...utils import (
|
||||||
|
TensorType,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_torchvision_available,
|
||||||
|
logging,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -44,6 +51,12 @@ if is_torch_available():
|
|||||||
if is_torchvision_available():
|
if is_torchvision_available():
|
||||||
from torchvision.ops.boxes import batched_nms
|
from torchvision.ops.boxes import batched_nms
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.experimental import numpy as tnp
|
||||||
|
|
||||||
|
from ...tf_utils import flatten, shape_list
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -372,6 +385,61 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
return encoded_outputs
|
return encoded_outputs
|
||||||
|
|
||||||
def post_process_masks(
|
def post_process_masks(
|
||||||
|
self,
|
||||||
|
masks,
|
||||||
|
original_sizes,
|
||||||
|
reshaped_input_sizes,
|
||||||
|
mask_threshold=0.0,
|
||||||
|
binarize=True,
|
||||||
|
pad_size=None,
|
||||||
|
return_tensors="pt",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Remove padding and upscale masks to the original image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
|
||||||
|
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
||||||
|
original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
|
||||||
|
The original sizes of each image before it was resized to the model's expected input shape, in (height,
|
||||||
|
width) format.
|
||||||
|
reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
|
||||||
|
The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
|
||||||
|
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||||||
|
The threshold to use for binarizing the masks.
|
||||||
|
binarize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to binarize the masks.
|
||||||
|
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
||||||
|
The target size the images were padded to before being passed to the model. If None, the target size is
|
||||||
|
assumed to be the processor's `pad_size`.
|
||||||
|
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
||||||
|
If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
|
||||||
|
Returns:
|
||||||
|
(`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where
|
||||||
|
(height, width) is given by original_size.
|
||||||
|
"""
|
||||||
|
if return_tensors == "pt":
|
||||||
|
return self._post_process_masks_pt(
|
||||||
|
masks=masks,
|
||||||
|
original_sizes=original_sizes,
|
||||||
|
reshaped_input_sizes=reshaped_input_sizes,
|
||||||
|
mask_threshold=mask_threshold,
|
||||||
|
binarize=binarize,
|
||||||
|
pad_size=pad_size,
|
||||||
|
)
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
return self._post_process_masks_tf(
|
||||||
|
masks=masks,
|
||||||
|
original_sizes=original_sizes,
|
||||||
|
reshaped_input_sizes=reshaped_input_sizes,
|
||||||
|
mask_threshold=mask_threshold,
|
||||||
|
binarize=binarize,
|
||||||
|
pad_size=pad_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("return_tensors must be either 'pt' or 'tf'")
|
||||||
|
|
||||||
|
def _post_process_masks_pt(
|
||||||
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -418,21 +486,70 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return output_masks
|
return output_masks
|
||||||
|
|
||||||
def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh):
|
def _post_process_masks_tf(
|
||||||
|
self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Remove padding and upscale masks to the original image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masks (`tf.Tensor`):
|
||||||
|
Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
|
||||||
|
original_sizes (`tf.Tensor`):
|
||||||
|
The original size of the images before resizing for input to the model, in (height, width) format.
|
||||||
|
reshaped_input_sizes (`tf.Tensor`):
|
||||||
|
The size of the image input to the model, in (height, width) format. Used to remove padding.
|
||||||
|
mask_threshold (`float`, *optional*, defaults to 0.0):
|
||||||
|
The threshold to use for binarizing the masks.
|
||||||
|
binarize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to binarize the masks.
|
||||||
|
pad_size (`int`, *optional*, defaults to `self.pad_size`):
|
||||||
|
The target size the images were padded to before being passed to the model. If None, the target size is
|
||||||
|
assumed to be the processor's `pad_size`.
|
||||||
|
Returns:
|
||||||
|
(`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is
|
||||||
|
given by original_size.
|
||||||
|
"""
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
pad_size = self.pad_size if pad_size is None else pad_size
|
||||||
|
target_image_size = (pad_size["height"], pad_size["width"])
|
||||||
|
|
||||||
|
output_masks = []
|
||||||
|
for i, original_size in enumerate(original_sizes):
|
||||||
|
# tf.image expects NHWC, we transpose the NCHW inputs for it
|
||||||
|
mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
|
||||||
|
interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
|
||||||
|
interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
|
||||||
|
interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
|
||||||
|
if binarize:
|
||||||
|
interpolated_mask = interpolated_mask > mask_threshold
|
||||||
|
# And then we transpose them back at the end
|
||||||
|
output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))
|
||||||
|
|
||||||
|
return output_masks
|
||||||
|
|
||||||
|
def post_process_for_mask_generation(
|
||||||
|
self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
|
Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
all_masks (`List[torch.Tensor]`):
|
all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
||||||
List of all predicted segmentation masks
|
List of all predicted segmentation masks
|
||||||
all_scores (`List[torch.Tensor]`):
|
all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
||||||
List of all predicted iou scores
|
List of all predicted iou scores
|
||||||
all_boxes (`List[torch.Tensor]`):
|
all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
|
||||||
List of all bounding boxes of the predicted masks
|
List of all bounding boxes of the predicted masks
|
||||||
crops_nms_thresh (`float`):
|
crops_nms_thresh (`float`):
|
||||||
Threshold for NMS (Non Maximum Suppression) algorithm.
|
Threshold for NMS (Non Maximum Suppression) algorithm.
|
||||||
|
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||||
|
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||||
"""
|
"""
|
||||||
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
if return_tensors == "pt":
|
||||||
|
return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)
|
||||||
|
|
||||||
def generate_crop_boxes(
|
def generate_crop_boxes(
|
||||||
self,
|
self,
|
||||||
@@ -443,6 +560,7 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
points_per_crop: Optional[int] = 32,
|
points_per_crop: Optional[int] = 32,
|
||||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||||
device: Optional["torch.device"] = None,
|
device: Optional["torch.device"] = None,
|
||||||
|
return_tensors: str = "pt",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||||
@@ -464,10 +582,35 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||||
device (`torch.device`, *optional*, defaults to None):
|
device (`torch.device`, *optional*, defaults to None):
|
||||||
Device to use for the computation. If None, cpu will be used.
|
Device to use for the computation. If None, cpu will be used.
|
||||||
|
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||||
|
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||||
"""
|
"""
|
||||||
return _generate_crop_boxes(
|
crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
|
||||||
image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
|
image,
|
||||||
|
target_size,
|
||||||
|
crop_n_layers,
|
||||||
|
overlap_ratio,
|
||||||
|
points_per_crop,
|
||||||
|
crop_n_points_downscale_factor,
|
||||||
)
|
)
|
||||||
|
if return_tensors == "pt":
|
||||||
|
if device is None:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
crop_boxes = torch.tensor(crop_boxes, device=device)
|
||||||
|
points_per_crop = torch.tensor(points_per_crop, device=device)
|
||||||
|
# cropped_images stays as np
|
||||||
|
input_labels = torch.tensor(input_labels, device=device)
|
||||||
|
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
if device is not None:
|
||||||
|
raise ValueError("device is not a supported argument when return_tensors is tf!")
|
||||||
|
crop_boxes = tf.convert_to_tensor(crop_boxes)
|
||||||
|
points_per_crop = tf.convert_to_tensor(points_per_crop)
|
||||||
|
# cropped_images stays as np
|
||||||
|
input_labels = tf.convert_to_tensor(input_labels)
|
||||||
|
else:
|
||||||
|
raise ValueError("return_tensors must be either 'pt' or 'tf'.")
|
||||||
|
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||||
|
|
||||||
def filter_masks(
|
def filter_masks(
|
||||||
self,
|
self,
|
||||||
@@ -479,6 +622,67 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
stability_score_thresh=0.95,
|
stability_score_thresh=0.95,
|
||||||
mask_threshold=0,
|
mask_threshold=0,
|
||||||
stability_score_offset=1,
|
stability_score_offset=1,
|
||||||
|
return_tensors="pt",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||||
|
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
||||||
|
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
||||||
|
bounding boxes and pad the predicted masks if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masks (`Union[torch.Tensor, tf.Tensor]`):
|
||||||
|
Input masks.
|
||||||
|
iou_scores (`Union[torch.Tensor, tf.Tensor]`):
|
||||||
|
List of IoU scores.
|
||||||
|
original_size (`Tuple[int,int]`):
|
||||||
|
Size of the orginal image.
|
||||||
|
cropped_box_image (`np.array`):
|
||||||
|
The cropped image.
|
||||||
|
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||||
|
The threshold for the iou scores.
|
||||||
|
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||||
|
The threshold for the stability score.
|
||||||
|
mask_threshold (`float`, *optional*, defaults to 0):
|
||||||
|
The threshold for the predicted masks.
|
||||||
|
stability_score_offset (`float`, *optional*, defaults to 1):
|
||||||
|
The offset for the stability score used in the `_compute_stability_score` method.
|
||||||
|
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||||
|
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||||
|
"""
|
||||||
|
if return_tensors == "pt":
|
||||||
|
return self._filter_masks_pt(
|
||||||
|
masks=masks,
|
||||||
|
iou_scores=iou_scores,
|
||||||
|
original_size=original_size,
|
||||||
|
cropped_box_image=cropped_box_image,
|
||||||
|
pred_iou_thresh=pred_iou_thresh,
|
||||||
|
stability_score_thresh=stability_score_thresh,
|
||||||
|
mask_threshold=mask_threshold,
|
||||||
|
stability_score_offset=stability_score_offset,
|
||||||
|
)
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
return self._filter_masks_tf(
|
||||||
|
masks=masks,
|
||||||
|
iou_scores=iou_scores,
|
||||||
|
original_size=original_size,
|
||||||
|
cropped_box_image=cropped_box_image,
|
||||||
|
pred_iou_thresh=pred_iou_thresh,
|
||||||
|
stability_score_thresh=stability_score_thresh,
|
||||||
|
mask_threshold=mask_threshold,
|
||||||
|
stability_score_offset=stability_score_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter_masks_pt(
|
||||||
|
self,
|
||||||
|
masks,
|
||||||
|
iou_scores,
|
||||||
|
original_size,
|
||||||
|
cropped_box_image,
|
||||||
|
pred_iou_thresh=0.88,
|
||||||
|
stability_score_thresh=0.95,
|
||||||
|
mask_threshold=0,
|
||||||
|
stability_score_offset=1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||||
@@ -525,7 +729,7 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
# compute stability score
|
# compute stability score
|
||||||
if stability_score_thresh > 0.0:
|
if stability_score_thresh > 0.0:
|
||||||
stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset)
|
stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
|
||||||
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
||||||
|
|
||||||
scores = iou_scores[keep_mask]
|
scores = iou_scores[keep_mask]
|
||||||
@@ -549,8 +753,85 @@ class SamImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return masks, scores, converted_boxes
|
return masks, scores, converted_boxes
|
||||||
|
|
||||||
|
def _filter_masks_tf(
|
||||||
|
self,
|
||||||
|
masks,
|
||||||
|
iou_scores,
|
||||||
|
original_size,
|
||||||
|
cropped_box_image,
|
||||||
|
pred_iou_thresh=0.88,
|
||||||
|
stability_score_thresh=0.95,
|
||||||
|
mask_threshold=0,
|
||||||
|
stability_score_offset=1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
|
||||||
|
that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
|
||||||
|
score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
|
||||||
|
bounding boxes and pad the predicted masks if necessary.
|
||||||
|
|
||||||
def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
Args:
|
||||||
|
masks (`tf.Tensor`):
|
||||||
|
Input masks.
|
||||||
|
iou_scores (`tf.Tensor`):
|
||||||
|
List of IoU scores.
|
||||||
|
original_size (`Tuple[int,int]`):
|
||||||
|
Size of the orginal image.
|
||||||
|
cropped_box_image (`np.array`):
|
||||||
|
The cropped image.
|
||||||
|
pred_iou_thresh (`float`, *optional*, defaults to 0.88):
|
||||||
|
The threshold for the iou scores.
|
||||||
|
stability_score_thresh (`float`, *optional*, defaults to 0.95):
|
||||||
|
The threshold for the stability score.
|
||||||
|
mask_threshold (`float`, *optional*, defaults to 0):
|
||||||
|
The threshold for the predicted masks.
|
||||||
|
stability_score_offset (`float`, *optional*, defaults to 1):
|
||||||
|
The offset for the stability score used in the `_compute_stability_score` method.
|
||||||
|
|
||||||
|
"""
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
original_height, original_width = original_size
|
||||||
|
iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
|
||||||
|
masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])
|
||||||
|
|
||||||
|
if masks.shape[0] != iou_scores.shape[0]:
|
||||||
|
raise ValueError("masks and iou_scores must have the same batch size.")
|
||||||
|
|
||||||
|
batch_size = masks.shape[0]
|
||||||
|
|
||||||
|
keep_mask = tf.ones(batch_size, dtype=tf.bool)
|
||||||
|
|
||||||
|
if pred_iou_thresh > 0.0:
|
||||||
|
keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
|
||||||
|
|
||||||
|
# compute stability score
|
||||||
|
if stability_score_thresh > 0.0:
|
||||||
|
stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
|
||||||
|
keep_mask = keep_mask & (stability_scores > stability_score_thresh)
|
||||||
|
|
||||||
|
scores = iou_scores[keep_mask]
|
||||||
|
masks = masks[keep_mask]
|
||||||
|
|
||||||
|
# binarize masks
|
||||||
|
masks = masks > mask_threshold
|
||||||
|
converted_boxes = _batched_mask_to_box_tf(masks)
|
||||||
|
|
||||||
|
keep_mask = ~_is_box_near_crop_edge_tf(
|
||||||
|
converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = scores[keep_mask]
|
||||||
|
masks = masks[keep_mask]
|
||||||
|
converted_boxes = converted_boxes[keep_mask]
|
||||||
|
|
||||||
|
masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)
|
||||||
|
# conversion to rle is necessary to run non-maximum suppresion
|
||||||
|
masks = _mask_to_rle_tf(masks)
|
||||||
|
|
||||||
|
return masks, scores, converted_boxes
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||||
# One mask is always contained inside the other.
|
# One mask is always contained inside the other.
|
||||||
# Save memory by preventing unnecesary cast to torch.int64
|
# Save memory by preventing unnecesary cast to torch.int64
|
||||||
intersections = (
|
intersections = (
|
||||||
@@ -561,6 +842,17 @@ def _compute_stability_score(masks: "torch.Tensor", mask_threshold: float, stabi
|
|||||||
return stability_scores
|
return stability_scores
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
|
||||||
|
# Torch does Py3-style division but TF does floor division with ints. We cast to float32 in TF to make sure
|
||||||
|
# we get the right division results.
|
||||||
|
intersections = tf.count_nonzero(
|
||||||
|
masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
|
||||||
|
)
|
||||||
|
unions = tf.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
|
||||||
|
stability_scores = intersections / unions
|
||||||
|
return stability_scores
|
||||||
|
|
||||||
|
|
||||||
def _build_point_grid(n_per_side: int) -> np.ndarray:
|
def _build_point_grid(n_per_side: int) -> np.ndarray:
|
||||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||||
offset = 1 / (2 * n_per_side)
|
offset = 1 / (2 * n_per_side)
|
||||||
@@ -606,7 +898,6 @@ def _generate_crop_boxes(
|
|||||||
overlap_ratio: float = 512 / 1500,
|
overlap_ratio: float = 512 / 1500,
|
||||||
points_per_crop: Optional[int] = 32,
|
points_per_crop: Optional[int] = 32,
|
||||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||||
device: Optional["torch.device"] = None,
|
|
||||||
) -> Tuple[List[List[int]], List[int]]:
|
) -> Tuple[List[List[int]], List[int]]:
|
||||||
"""
|
"""
|
||||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||||
@@ -626,11 +917,7 @@ def _generate_crop_boxes(
|
|||||||
Number of points to sample per crop.
|
Number of points to sample per crop.
|
||||||
crop_n_points_downscale_factor (`int`, *optional*):
|
crop_n_points_downscale_factor (`int`, *optional*):
|
||||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||||
device (`torch.device`, *optional*):
|
|
||||||
Device to run the crop generation on. Defaults to CPU.
|
|
||||||
"""
|
"""
|
||||||
if device is None:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
|
|
||||||
if isinstance(image, list):
|
if isinstance(image, list):
|
||||||
raise ValueError("Only one image is allowed for crop generation.")
|
raise ValueError("Only one image is allowed for crop generation.")
|
||||||
@@ -648,12 +935,11 @@ def _generate_crop_boxes(
|
|||||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||||
)
|
)
|
||||||
|
|
||||||
crop_boxes = torch.tensor(crop_boxes, dtype=torch.float32, device=device)
|
crop_boxes = crop_boxes.astype(np.float32)
|
||||||
point_grid_per_crop = np.array([point_grid_per_crop])
|
points_per_crop = np.array([point_grid_per_crop])
|
||||||
points_per_crop = torch.tensor(point_grid_per_crop, device=device)
|
points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
|
||||||
points_per_crop = points_per_crop.permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
input_labels = torch.ones_like(points_per_crop[:, :, :, 0], dtype=torch.long, device=device)
|
input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
|
||||||
|
|
||||||
return crop_boxes, points_per_crop, cropped_images, input_labels
|
return crop_boxes, points_per_crop, cropped_images, input_labels
|
||||||
|
|
||||||
@@ -730,6 +1016,16 @@ def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
|||||||
return torch.nn.functional.pad(masks, pad, value=0)
|
return torch.nn.functional.pad(masks, pad, value=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
|
||||||
|
left, top, right, bottom = crop_box
|
||||||
|
if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
|
||||||
|
return masks
|
||||||
|
# Coordinate transform masks
|
||||||
|
pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
|
||||||
|
pad = (left, pad_x - left, top, pad_y - top)
|
||||||
|
return tf.pad(masks, pad, constant_values=0)
|
||||||
|
|
||||||
|
|
||||||
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
||||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||||
@@ -748,6 +1044,24 @@ def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
|
|||||||
return torch.any(near_crop_edge, dim=1)
|
return torch.any(near_crop_edge, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
|
||||||
|
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||||
|
crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
|
||||||
|
orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)
|
||||||
|
|
||||||
|
left, top, _, _ = crop_box
|
||||||
|
offset = tf.convert_to_tensor([[left, top, left, top]])
|
||||||
|
# Check if boxes has a channel dimension
|
||||||
|
if len(boxes.shape) == 3:
|
||||||
|
offset = tf.expand_dims(offset, 1)
|
||||||
|
boxes = tf.cast(boxes + offset, tf.float32)
|
||||||
|
|
||||||
|
near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
|
||||||
|
near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
|
||||||
|
near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
|
||||||
|
return tf.reduce_any(near_crop_edge, axis=1)
|
||||||
|
|
||||||
|
|
||||||
def _batched_mask_to_box(masks: "torch.Tensor"):
|
def _batched_mask_to_box(masks: "torch.Tensor"):
|
||||||
"""
|
"""
|
||||||
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
||||||
@@ -797,6 +1111,54 @@ def _batched_mask_to_box(masks: "torch.Tensor"):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _batched_mask_to_box_tf(masks: "tf.Tensor"):
|
||||||
|
"""
|
||||||
|
Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
|
||||||
|
corresponds the following required indices:
|
||||||
|
- LEFT: left hand side of the bounding box
|
||||||
|
- TOP: top of the bounding box
|
||||||
|
- RIGHT: right of the bounding box
|
||||||
|
- BOTTOM: bottom of the bounding box
|
||||||
|
|
||||||
|
Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
|
||||||
|
is channel_1 x channel_2 x ... x 4.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- masks (`tf.Tensor` of shape `(batch, nb_mask, height, width)`)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if tf.size(masks) == 0:
|
||||||
|
return tf.zeros([*masks.shape[:-2], 4])
|
||||||
|
|
||||||
|
# Normalize shape to Cxheightxwidth
|
||||||
|
shape = shape_list(masks)
|
||||||
|
height, width = shape[-2:]
|
||||||
|
|
||||||
|
# Get top and bottom edges
|
||||||
|
in_height = tf.reduce_max(masks, axis=-1)
|
||||||
|
in_height_coords = in_height * tf.range(height)[None, :]
|
||||||
|
bottom_edges = tf.reduce_max(in_height_coords, axis=-1)
|
||||||
|
in_height_coords = in_height_coords + height * (~in_height)
|
||||||
|
top_edges = tf.reduce_min(in_height_coords, axis=-1)
|
||||||
|
|
||||||
|
# Get left and right edges
|
||||||
|
in_width, _ = tf.reduce_max(masks, axis=-2)
|
||||||
|
in_width_coords = in_width * tf.range(width)[None, :]
|
||||||
|
right_edges, _ = tf.reduce_max(in_width_coords, axis=-1)
|
||||||
|
in_width_coords = in_width_coords + width * (~in_width)
|
||||||
|
left_edges, _ = tf.reduce_min(in_width_coords, axis=-1)
|
||||||
|
|
||||||
|
# If the mask is empty the right edge will be to the left of the left edge.
|
||||||
|
# Replace these boxes with [0, 0, 0, 0]
|
||||||
|
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||||
|
out = tf.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
|
||||||
|
out = out * tf.expand_dims(~empty_filter, -1)
|
||||||
|
|
||||||
|
# Return to original shape
|
||||||
|
out = tf.reshape(out, *shape[:-2], 4)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
||||||
"""
|
"""
|
||||||
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
||||||
@@ -820,6 +1182,29 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_to_rle_tf(input_mask: "tf.Tensor"):
|
||||||
|
"""
|
||||||
|
Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools.
|
||||||
|
"""
|
||||||
|
# Put in fortran order and flatten height and width
|
||||||
|
batch_size, height, width = input_mask.shape
|
||||||
|
input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)
|
||||||
|
|
||||||
|
# Compute change indices
|
||||||
|
diff = input_mask[:, 1:] ^ input_mask[:, :-1]
|
||||||
|
change_indices = tf.where(diff)
|
||||||
|
|
||||||
|
# Encode run length
|
||||||
|
out = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
|
||||||
|
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||||
|
counts = [] if input_mask[i, 0] == 0 else [0]
|
||||||
|
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
|
||||||
|
out.append({"size": [height, width], "counts": counts})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||||
"""Compute a binary mask from an uncompressed RLE."""
|
"""Compute a binary mask from an uncompressed RLE."""
|
||||||
height, width = rle["size"]
|
height, width = rle["size"]
|
||||||
@@ -836,7 +1221,7 @@ def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
|||||||
|
|
||||||
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
||||||
"""
|
"""
|
||||||
Perform NMS (Non Maxium Suppression) on the outputs.
|
Perform NMS (Non Maximum Suppression) on the outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rle_masks (`torch.Tensor`):
|
rle_masks (`torch.Tensor`):
|
||||||
@@ -861,3 +1246,32 @@ def _postprocess_for_mg(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=
|
|||||||
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
||||||
|
|
||||||
return masks, iou_scores, rle_masks, mask_boxes
|
return masks, iou_scores, rle_masks, mask_boxes
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
|
||||||
|
"""
|
||||||
|
Perform NMS (Non Maximum Suppression) on the outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rle_masks (`tf.Tensor`):
|
||||||
|
binary masks in the RLE format
|
||||||
|
iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
|
||||||
|
iou_scores predicted by the model
|
||||||
|
mask_boxes (`tf.Tensor`):
|
||||||
|
The bounding boxes corresponding to segmentation masks
|
||||||
|
amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
|
||||||
|
NMS threshold.
|
||||||
|
"""
|
||||||
|
keep_by_nms = tf.image.combined_non_max_suppression(
|
||||||
|
boxes=mask_boxes.float(),
|
||||||
|
scores=iou_scores,
|
||||||
|
idxs=torch.zeros(mask_boxes.shape[0]),
|
||||||
|
iou_threshold=amg_crops_nms_thresh,
|
||||||
|
)
|
||||||
|
|
||||||
|
iou_scores = iou_scores[keep_by_nms]
|
||||||
|
rle_masks = [rle_masks[i] for i in keep_by_nms]
|
||||||
|
mask_boxes = mask_boxes[keep_by_nms]
|
||||||
|
masks = [_rle_to_mask(rle) for rle in rle_masks]
|
||||||
|
|
||||||
|
return masks, iou_scores, rle_masks, mask_boxes
|
||||||
|
|||||||
@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
|
|||||||
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
|
|
||||||
class SamPatchEmbeddings(nn.Module):
|
class SamPatchEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||||
@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
|
|||||||
values.
|
values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, downsample_rate=None) -> None:
|
def __init__(self, config, downsample_rate=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SamTwoWayAttentionBlock(nn.Module):
|
class SamTwoWayAttentionBlock(nn.Module):
|
||||||
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
|
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
|
||||||
"""
|
"""
|
||||||
A transformer block with four layers:
|
A transformer block with four layers:
|
||||||
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
|
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
|
||||||
@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
|
|||||||
the embeddings of the mask inputs
|
the embeddings of the mask inputs
|
||||||
multimask_output (bool):
|
multimask_output (bool):
|
||||||
Whether to return multiple masks or a single mask.
|
Whether to return multiple masks or a single mask.
|
||||||
output_attentions (bool, **optional**):
|
output_attentions (bool, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers.
|
Whether or not to return the attentions tensors of all attention layers.
|
||||||
"""
|
"""
|
||||||
batch_size, num_channels, height, width = image_embeddings.shape
|
batch_size, num_channels, height, width = image_embeddings.shape
|
||||||
@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
|
|||||||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
points (`torch.Tensor`, **optionnal**):
|
points (`torch.Tensor`, *optional*):
|
||||||
point coordinates and labels to embed.
|
point coordinates and labels to embed.
|
||||||
boxes (`torch.Tensor`, **optionnal**):
|
boxes (`torch.Tensor`, *optional*):
|
||||||
boxes to embed
|
boxes to embed
|
||||||
masks (`torch.Tensor`, **optionnal**):
|
masks (`torch.Tensor`, *optional*):
|
||||||
masks to embed
|
masks to embed
|
||||||
"""
|
"""
|
||||||
sparse_embeddings = None
|
sparse_embeddings = None
|
||||||
@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
|
|||||||
class SamVisionAttention(nn.Module):
|
class SamVisionAttention(nn.Module):
|
||||||
"""Multi-head Attention block with relative position embeddings."""
|
"""Multi-head Attention block with relative position embeddings."""
|
||||||
|
|
||||||
def __init__(self, config, window_size) -> None:
|
def __init__(self, config, window_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
input_size = (
|
input_size = (
|
||||||
(config.image_size // config.patch_size, config.image_size // config.patch_size)
|
(config.image_size // config.patch_size, config.image_size // config.patch_size)
|
||||||
@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SamVisionLayer(nn.Module):
|
class SamVisionLayer(nn.Module):
|
||||||
def __init__(self, config, window_size) -> None:
|
def __init__(self, config, window_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.attn = SamVisionAttention(config, window_size)
|
self.attn = SamVisionAttention(config, window_size)
|
||||||
@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
|
|||||||
class SamModel(SamPreTrainedModel):
|
class SamModel(SamPreTrainedModel):
|
||||||
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
|
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
|
||||||
|
|
||||||
def __init__(self, config) -> None:
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
|
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
|
||||||
|
|
||||||
@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
|
|||||||
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
|
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
|
||||||
|
|
||||||
vision_attentions = None
|
vision_attentions = None
|
||||||
mask_decoder_attentions = None
|
|
||||||
vision_hidden_states = None
|
vision_hidden_states = None
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
|
|||||||
"The batch size of the image embeddings and the input points must be the same. ",
|
"The batch size of the image embeddings and the input points must be the same. ",
|
||||||
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
|
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
|
||||||
" if you want to pass multiple points for the same image, make sure that you passed ",
|
" if you want to pass multiple points for the same image, make sure that you passed ",
|
||||||
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
|
||||||
|
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
||||||
)
|
)
|
||||||
|
|
||||||
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
||||||
|
|||||||
1497
src/transformers/models/sam/modeling_tf_sam.py
Normal file
1497
src/transformers/models/sam/modeling_tf_sam.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,12 +22,15 @@ import numpy as np
|
|||||||
|
|
||||||
from ...processing_utils import ProcessorMixin
|
from ...processing_utils import ProcessorMixin
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...utils import TensorType, is_torch_available
|
from ...utils import TensorType, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
class SamProcessor(ProcessorMixin):
|
class SamProcessor(ProcessorMixin):
|
||||||
r"""
|
r"""
|
||||||
@@ -72,7 +75,7 @@ class SamProcessor(ProcessorMixin):
|
|||||||
# pop arguments that are not used in the foward but used nevertheless
|
# pop arguments that are not used in the foward but used nevertheless
|
||||||
original_sizes = encoding_image_processor["original_sizes"]
|
original_sizes = encoding_image_processor["original_sizes"]
|
||||||
|
|
||||||
if isinstance(original_sizes, torch.Tensor):
|
if hasattr(original_sizes, "numpy"): # Checks if Torch or TF tensor
|
||||||
original_sizes = original_sizes.numpy()
|
original_sizes = original_sizes.numpy()
|
||||||
|
|
||||||
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
|
input_points, input_labels, input_boxes = self._check_and_preprocess_points(
|
||||||
@@ -139,18 +142,30 @@ class SamProcessor(ProcessorMixin):
|
|||||||
input_boxes = torch.from_numpy(input_boxes)
|
input_boxes = torch.from_numpy(input_boxes)
|
||||||
# boxes batch size of 1 by default
|
# boxes batch size of 1 by default
|
||||||
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
|
input_boxes = input_boxes.unsqueeze(1) if len(input_boxes.shape) != 3 else input_boxes
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
input_boxes = tf.convert_to_tensor(input_boxes)
|
||||||
|
# boxes batch size of 1 by default
|
||||||
|
input_boxes = tf.expand_dims(input_boxes, 1) if len(input_boxes.shape) != 3 else input_boxes
|
||||||
encoding_image_processor.update({"input_boxes": input_boxes})
|
encoding_image_processor.update({"input_boxes": input_boxes})
|
||||||
if input_points is not None:
|
if input_points is not None:
|
||||||
if return_tensors == "pt":
|
if return_tensors == "pt":
|
||||||
input_points = torch.from_numpy(input_points)
|
input_points = torch.from_numpy(input_points)
|
||||||
# point batch size of 1 by default
|
# point batch size of 1 by default
|
||||||
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
|
input_points = input_points.unsqueeze(1) if len(input_points.shape) != 4 else input_points
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
input_points = tf.convert_to_tensor(input_points)
|
||||||
|
# point batch size of 1 by default
|
||||||
|
input_points = tf.expand_dims(input_points, 1) if len(input_points.shape) != 4 else input_points
|
||||||
encoding_image_processor.update({"input_points": input_points})
|
encoding_image_processor.update({"input_points": input_points})
|
||||||
if input_labels is not None:
|
if input_labels is not None:
|
||||||
if return_tensors == "pt":
|
if return_tensors == "pt":
|
||||||
input_labels = torch.from_numpy(input_labels)
|
input_labels = torch.from_numpy(input_labels)
|
||||||
# point batch size of 1 by default
|
# point batch size of 1 by default
|
||||||
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
|
input_labels = input_labels.unsqueeze(1) if len(input_labels.shape) != 3 else input_labels
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
input_labels = tf.convert_to_tensor(input_labels)
|
||||||
|
# point batch size of 1 by default
|
||||||
|
input_labels = tf.expand_dims(input_labels, 1) if len(input_labels.shape) != 3 else input_labels
|
||||||
encoding_image_processor.update({"input_labels": input_labels})
|
encoding_image_processor.update({"input_labels": input_labels})
|
||||||
|
|
||||||
return encoding_image_processor
|
return encoding_image_processor
|
||||||
@@ -204,7 +219,7 @@ class SamProcessor(ProcessorMixin):
|
|||||||
it is converted to a `numpy.ndarray` and then to a `list`.
|
it is converted to a `numpy.ndarray` and then to a `list`.
|
||||||
"""
|
"""
|
||||||
if input_points is not None:
|
if input_points is not None:
|
||||||
if isinstance(input_points, torch.Tensor):
|
if hasattr(input_points, "numpy"): # Checks for TF or Torch tensor
|
||||||
input_points = input_points.numpy().tolist()
|
input_points = input_points.numpy().tolist()
|
||||||
|
|
||||||
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
|
if not isinstance(input_points, list) or not isinstance(input_points[0], list):
|
||||||
@@ -214,7 +229,7 @@ class SamProcessor(ProcessorMixin):
|
|||||||
input_points = None
|
input_points = None
|
||||||
|
|
||||||
if input_labels is not None:
|
if input_labels is not None:
|
||||||
if isinstance(input_labels, torch.Tensor):
|
if hasattr(input_labels, "numpy"):
|
||||||
input_labels = input_labels.numpy().tolist()
|
input_labels = input_labels.numpy().tolist()
|
||||||
|
|
||||||
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
|
if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
|
||||||
@@ -224,7 +239,7 @@ class SamProcessor(ProcessorMixin):
|
|||||||
input_labels = None
|
input_labels = None
|
||||||
|
|
||||||
if input_boxes is not None:
|
if input_boxes is not None:
|
||||||
if isinstance(input_boxes, torch.Tensor):
|
if hasattr(input_boxes, "numpy"):
|
||||||
input_boxes = input_boxes.numpy().tolist()
|
input_boxes = input_boxes.numpy().tolist()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -70,6 +70,56 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
|
|||||||
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
|
return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
|
||||||
|
# This is a very simplified functional layernorm, designed to duplicate
|
||||||
|
# the functionality of PyTorch nn.functional.layer_norm when this is needed to port
|
||||||
|
# models in Transformers.
|
||||||
|
|
||||||
|
if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
|
||||||
|
raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
|
||||||
|
|
||||||
|
# Get mean and variance on the axis to be normalized
|
||||||
|
mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
|
||||||
|
|
||||||
|
if axis != -1:
|
||||||
|
# Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
|
||||||
|
# on every dimension except axis
|
||||||
|
shape = [1] * inputs.shape.rank
|
||||||
|
shape[axis] = shape_list(inputs)[axis]
|
||||||
|
weight = tf.reshape(weight, shape)
|
||||||
|
bias = tf.reshape(bias, shape)
|
||||||
|
|
||||||
|
# Compute layer normalization using the batch_normalization
|
||||||
|
# function.
|
||||||
|
outputs = tf.nn.batch_normalization(
|
||||||
|
inputs,
|
||||||
|
mean,
|
||||||
|
variance,
|
||||||
|
offset=bias,
|
||||||
|
scale=weight,
|
||||||
|
variance_epsilon=epsilon,
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(input, start_dim=0, end_dim=-1):
|
||||||
|
# Replicates the behavior of torch.flatten in TF
|
||||||
|
|
||||||
|
# If end_dim or start_dim is negative, count them from the end
|
||||||
|
if end_dim < 0:
|
||||||
|
end_dim += input.shape.rank
|
||||||
|
if start_dim < 0:
|
||||||
|
start_dim += input.shape.rank
|
||||||
|
|
||||||
|
if start_dim == end_dim:
|
||||||
|
return input
|
||||||
|
|
||||||
|
in_shape = tf.shape(input)
|
||||||
|
flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
|
||||||
|
out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
|
||||||
|
return tf.reshape(input, out_shape)
|
||||||
|
|
||||||
|
|
||||||
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
|
def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
|
||||||
"""
|
"""
|
||||||
Invert an attention mask (e.g., switches 0. and 1.).
|
Invert an attention mask (e.g., switches 0. and 1.).
|
||||||
|
|||||||
@@ -2317,6 +2317,23 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class TFSamModel(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFSamPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -436,6 +436,9 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
|
||||||
|
super().test_pt_tf_model_equivalence(allow_missing_keys=True, tol=tol)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
@@ -470,8 +473,10 @@ class SamModelIntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
scores = outputs.iou_scores.squeeze()
|
scores = outputs.iou_scores.squeeze()
|
||||||
|
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=1e-4))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.5798), atol=2e-4))
|
||||||
|
self.assertTrue(torch.allclose(masks, torch.tensor([-6.6381, -6.0734, -7.5308]).to(torch_device), atol=2e-4))
|
||||||
|
|
||||||
def test_inference_mask_generation_one_point_one_bb(self):
|
def test_inference_mask_generation_one_point_one_bb(self):
|
||||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
@@ -491,8 +496,12 @@ class SamModelIntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
scores = outputs.iou_scores.squeeze()
|
scores = outputs.iou_scores.squeeze()
|
||||||
|
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=1e-4))
|
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9935), atol=2e-4))
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(masks, torch.tensor([-21.5465, -23.1122, -22.3331]).to(torch_device), atol=2e-4)
|
||||||
|
)
|
||||||
|
|
||||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
@@ -514,6 +523,7 @@ class SamModelIntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
scores = outputs.iou_scores.squeeze().cpu()
|
scores = outputs.iou_scores.squeeze().cpu()
|
||||||
|
masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()
|
||||||
|
|
||||||
EXPECTED_SCORES = torch.tensor(
|
EXPECTED_SCORES = torch.tensor(
|
||||||
[
|
[
|
||||||
@@ -531,7 +541,9 @@ class SamModelIntegrationTest(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
EXPECTED_MASKS = torch.tensor([-26.5424, -34.0901, -30.6406])
|
||||||
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
|
self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
|
||||||
|
self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))
|
||||||
|
|
||||||
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
||||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|||||||
671
tests/models/sam/test_modeling_tf_sam.py
Normal file
671
tests/models/sam/test_modeling_tf_sam.py
Normal file
@@ -0,0 +1,671 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Testing suite for the TensorFlow SAM model. """
|
||||||
|
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
|
||||||
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
from transformers.utils import is_tf_available, is_vision_available
|
||||||
|
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from transformers import SamProcessor, TFSamModel
|
||||||
|
from transformers.models.sam.modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class TFSamPromptEncoderTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=32,
|
||||||
|
input_image_size=24,
|
||||||
|
patch_size=2,
|
||||||
|
mask_input_channels=4,
|
||||||
|
num_point_embeddings=4,
|
||||||
|
hidden_act="gelu",
|
||||||
|
):
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.input_image_size = input_image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.mask_input_channels = mask_input_channels
|
||||||
|
self.num_point_embeddings = num_point_embeddings
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return SamPromptEncoderConfig(
|
||||||
|
image_size=self.input_image_size,
|
||||||
|
patch_size=self.patch_size,
|
||||||
|
mask_input_channels=self.mask_input_channels,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_point_embeddings=self.num_point_embeddings,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
dummy_points = floats_tensor([self.batch_size, 3, 2])
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, dummy_points
|
||||||
|
|
||||||
|
|
||||||
|
class TFSamMaskDecoderTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=32,
|
||||||
|
hidden_act="relu",
|
||||||
|
mlp_dim=64,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
attention_downsample_rate=2,
|
||||||
|
num_multimask_outputs=3,
|
||||||
|
iou_head_depth=3,
|
||||||
|
iou_head_hidden_dim=32,
|
||||||
|
layer_norm_eps=1e-6,
|
||||||
|
):
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.mlp_dim = mlp_dim
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.attention_downsample_rate = attention_downsample_rate
|
||||||
|
self.num_multimask_outputs = num_multimask_outputs
|
||||||
|
self.iou_head_depth = iou_head_depth
|
||||||
|
self.iou_head_hidden_dim = iou_head_hidden_dim
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return SamMaskDecoderConfig(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
mlp_dim=self.mlp_dim,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
attention_downsample_rate=self.attention_downsample_rate,
|
||||||
|
num_multimask_outputs=self.num_multimask_outputs,
|
||||||
|
iou_head_depth=self.iou_head_depth,
|
||||||
|
iou_head_hidden_dim=self.iou_head_hidden_dim,
|
||||||
|
layer_norm_eps=self.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
dummy_inputs = {
|
||||||
|
"image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, dummy_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class TFSamModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
hidden_size=36,
|
||||||
|
intermediate_size=72,
|
||||||
|
projection_dim=62,
|
||||||
|
output_channels=32,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=24,
|
||||||
|
patch_size=2,
|
||||||
|
hidden_act="gelu",
|
||||||
|
layer_norm_eps=1e-06,
|
||||||
|
dropout=0.0,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
initializer_factor=1.0,
|
||||||
|
qkv_bias=True,
|
||||||
|
mlp_ratio=4.0,
|
||||||
|
use_abs_pos=True,
|
||||||
|
use_rel_pos=True,
|
||||||
|
rel_pos_zero_init=False,
|
||||||
|
window_size=14,
|
||||||
|
global_attn_indexes=[2, 5, 8, 11],
|
||||||
|
num_pos_feats=16,
|
||||||
|
mlp_dim=None,
|
||||||
|
batch_size=2,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.output_channels = output_channels
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.projection_dim = projection_dim
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.initializer_factor = initializer_factor
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.qkv_bias = qkv_bias
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.use_abs_pos = use_abs_pos
|
||||||
|
self.use_rel_pos = use_rel_pos
|
||||||
|
self.rel_pos_zero_init = rel_pos_zero_init
|
||||||
|
self.window_size = window_size
|
||||||
|
self.global_attn_indexes = global_attn_indexes
|
||||||
|
self.num_pos_feats = num_pos_feats
|
||||||
|
self.mlp_dim = mlp_dim
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||||
|
num_patches = (image_size // patch_size) ** 2
|
||||||
|
self.seq_length = num_patches + 1
|
||||||
|
|
||||||
|
self.prompt_encoder_tester = TFSamPromptEncoderTester()
|
||||||
|
self.mask_decoder_tester = TFSamMaskDecoderTester()
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
vision_config = SamVisionConfig(
|
||||||
|
image_size=self.image_size,
|
||||||
|
patch_size=self.patch_size,
|
||||||
|
num_channels=self.num_channels,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
projection_dim=self.projection_dim,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
dropout=self.dropout,
|
||||||
|
attention_dropout=self.attention_dropout,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
initializer_factor=self.initializer_factor,
|
||||||
|
output_channels=self.output_channels,
|
||||||
|
qkv_bias=self.qkv_bias,
|
||||||
|
mlp_ratio=self.mlp_ratio,
|
||||||
|
use_abs_pos=self.use_abs_pos,
|
||||||
|
use_rel_pos=self.use_rel_pos,
|
||||||
|
rel_pos_zero_init=self.rel_pos_zero_init,
|
||||||
|
window_size=self.window_size,
|
||||||
|
global_attn_indexes=self.global_attn_indexes,
|
||||||
|
num_pos_feats=self.num_pos_feats,
|
||||||
|
mlp_dim=self.mlp_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_encoder_config = self.prompt_encoder_tester.get_config()
|
||||||
|
|
||||||
|
mask_decoder_config = self.mask_decoder_tester.get_config()
|
||||||
|
|
||||||
|
return SamConfig(
|
||||||
|
vision_config=vision_config,
|
||||||
|
prompt_encoder_config=prompt_encoder_config,
|
||||||
|
mask_decoder_config=mask_decoder_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, pixel_values):
|
||||||
|
model = TFSamModel(config=config)
|
||||||
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
|
||||||
|
self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
|
||||||
|
|
||||||
|
def create_and_check_get_image_features(self, config, pixel_values):
|
||||||
|
model = TFSamModel(config=config)
|
||||||
|
result = model.get_image_embeddings(pixel_values)
|
||||||
|
self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
|
||||||
|
|
||||||
|
def create_and_check_get_image_hidden_states(self, config, pixel_values):
|
||||||
|
model = TFSamModel(config=config)
|
||||||
|
result = model.vision_encoder(
|
||||||
|
pixel_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# after computing the convolutional features
|
||||||
|
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
|
||||||
|
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
|
||||||
|
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
|
||||||
|
|
||||||
|
result = model.vision_encoder(
|
||||||
|
pixel_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# after computing the convolutional features
|
||||||
|
expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
|
||||||
|
self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
|
||||||
|
self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {"pixel_values": pixel_values}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds,
|
||||||
|
attention_mask and seq_length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_model_classes = (TFSamModel,) if is_tf_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": TFSamModel, "mask-generation": TFSamModel} if is_tf_available() else {}
|
||||||
|
)
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
|
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
|
||||||
|
def is_pipeline_test_to_skip(
|
||||||
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = TFSamModelTester(self)
|
||||||
|
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
|
||||||
|
self.prompt_encoder_config_tester = ConfigTester(
|
||||||
|
self,
|
||||||
|
config_class=SamPromptEncoderConfig,
|
||||||
|
has_text_modality=False,
|
||||||
|
num_attention_heads=12,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
)
|
||||||
|
self.mask_decoder_config_tester = ConfigTester(
|
||||||
|
self, config_class=SamMaskDecoderConfig, has_text_modality=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.vision_config_tester.run_common_tests()
|
||||||
|
self.prompt_encoder_config_tester.run_common_tests()
|
||||||
|
self.mask_decoder_config_tester.run_common_tests()
|
||||||
|
|
||||||
|
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
|
||||||
|
x = model.get_output_embeddings()
|
||||||
|
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Dense))
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.call)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_get_image_features(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_get_image_features(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_image_hidden_states(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
expected_vision_attention_shape = (
|
||||||
|
self.model_tester.batch_size * self.model_tester.num_attention_heads,
|
||||||
|
196,
|
||||||
|
196,
|
||||||
|
)
|
||||||
|
expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = model_class(config)
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
vision_attentions = outputs.vision_attentions
|
||||||
|
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
mask_decoder_attentions = outputs.mask_decoder_attentions
|
||||||
|
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
# check that output_attentions also work using config
|
||||||
|
del inputs_dict["output_attentions"]
|
||||||
|
config.output_attentions = True
|
||||||
|
model = model_class(config)
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
vision_attentions = outputs.vision_attentions
|
||||||
|
self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
mask_decoder_attentions = outputs.mask_decoder_attentions
|
||||||
|
self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(vision_attentions[0].shape[-4:]),
|
||||||
|
list(expected_vision_attention_shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(mask_decoder_attentions[0].shape[-4:]),
|
||||||
|
list(expected_mask_decoder_attention_shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_name in TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
model = TFSamModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
|
||||||
|
super().check_pt_tf_outputs(
|
||||||
|
tf_outputs=tf_outputs,
|
||||||
|
pt_outputs=pt_outputs,
|
||||||
|
model_class=model_class,
|
||||||
|
tol=tol,
|
||||||
|
name=name,
|
||||||
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_image():
|
||||||
|
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||||
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||||
|
return raw_image
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dog_img():
|
||||||
|
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
|
||||||
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||||
|
return raw_image
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
class SamModelIntegrationTest(unittest.TestCase):
|
||||||
|
def test_inference_mask_generation_no_point(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
inputs = processor(images=raw_image, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4))
|
||||||
|
self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_one_point_one_bb(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
input_boxes = [[[650, 900, 1000, 1250]]]
|
||||||
|
input_points = [[[820, 1080]]]
|
||||||
|
|
||||||
|
inputs = processor(images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1], np.array(0.9935), atol=2e-4))
|
||||||
|
self.assertTrue(np.allclose(masks.numpy(), np.array([-21.5465, -23.1122, -22.3331]), atol=2e-2))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
input_points = [
|
||||||
|
[[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
|
||||||
|
[[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||||
|
|
||||||
|
EXPECTED_SCORES = np.array(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0.8405, 0.6292, 0.3840],
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
[0.9673, 0.9441, 0.9084],
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
EXPECTED_MASKS = np.array([-26.5424, -34.0901, -30.6406])
|
||||||
|
self.assertTrue(np.allclose(scores.numpy(), EXPECTED_SCORES, atol=1e-3))
|
||||||
|
self.assertTrue(np.allclose(masks.numpy(), EXPECTED_MASKS, atol=3e-2))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_one_point_one_bb_zero(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
input_boxes = [[[620, 900, 1000, 1255]]]
|
||||||
|
input_points = [[[820, 1080]]]
|
||||||
|
labels = [[0]]
|
||||||
|
|
||||||
|
inputs = processor(
|
||||||
|
images=raw_image,
|
||||||
|
input_boxes=input_boxes,
|
||||||
|
input_points=input_points,
|
||||||
|
input_labels=labels,
|
||||||
|
return_tensors="tf",
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9689), atol=1e-4))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_one_point(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
input_points = [[[400, 650]]]
|
||||||
|
input_labels = [[1]]
|
||||||
|
|
||||||
|
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1], np.array(0.9712), atol=1e-4))
|
||||||
|
|
||||||
|
# With no label
|
||||||
|
input_points = [[[400, 650]]]
|
||||||
|
|
||||||
|
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9712), atol=1e-4))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_two_points(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
input_points = [[[400, 650], [800, 650]]]
|
||||||
|
input_labels = [[1, 1]]
|
||||||
|
|
||||||
|
inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
|
||||||
|
|
||||||
|
# no labels
|
||||||
|
inputs = processor(images=raw_image, input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.9936), atol=1e-4))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_two_points_batched(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
input_points = [[[400, 650], [800, 650]], [[400, 650]]]
|
||||||
|
input_labels = [[1, 1], [1]]
|
||||||
|
|
||||||
|
inputs = processor(
|
||||||
|
images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[0][-1].numpy(), np.array(0.9936), atol=1e-4))
|
||||||
|
self.assertTrue(np.allclose(scores[1][-1], np.array(0.9716), atol=1e-4))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_one_box(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
input_boxes = [[[75, 275, 1725, 850]]]
|
||||||
|
|
||||||
|
inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.8686), atol=1e-4))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_batched_image_one_point(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
raw_dog_image = prepare_dog_img()
|
||||||
|
|
||||||
|
input_points = [[[820, 1080]], [[220, 470]]]
|
||||||
|
|
||||||
|
inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores_batched = tf.squeeze(outputs.iou_scores)
|
||||||
|
|
||||||
|
input_points = [[[220, 470]]]
|
||||||
|
|
||||||
|
inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
scores_single = tf.squeeze(outputs.iou_scores)
|
||||||
|
self.assertTrue(np.allclose(scores_batched[1, :].numpy(), scores_single.numpy(), atol=1e-4))
|
||||||
|
|
||||||
|
def test_inference_mask_generation_two_points_point_batch(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_points = tf.convert_to_tensor([[[400, 650]], [[220, 470]]])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
input_points = tf.expand_dims(input_points, 0)
|
||||||
|
|
||||||
|
inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
iou_scores = outputs.iou_scores
|
||||||
|
self.assertTrue(iou_scores.shape == (1, 2, 3))
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(
|
||||||
|
iou_scores.numpy(),
|
||||||
|
np.array([[[0.9848, 0.9788, 0.9713], [0.9211, 0.9128, 0.7427]]]),
|
||||||
|
atol=1e-4,
|
||||||
|
rtol=1e-4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_inference_mask_generation_three_boxes_point_batch(self):
|
||||||
|
model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
|
|
||||||
|
raw_image = prepare_image()
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
input_boxes = tf.convert_to_tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]])
|
||||||
|
EXPECTED_IOU = np.array([[[1.0071, 1.0032, 0.9946], [0.4962, 0.8770, 0.8686], [0.4962, 0.8770, 0.8686]]])
|
||||||
|
# fmt: on
|
||||||
|
input_boxes = tf.expand_dims(input_boxes, 0)
|
||||||
|
|
||||||
|
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf")
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
iou_scores = outputs.iou_scores
|
||||||
|
self.assertTrue(iou_scores.shape == (1, 3, 3))
|
||||||
|
self.assertTrue(np.allclose(iou_scores.numpy(), EXPECTED_IOU, atol=1e-4, rtol=1e-4))
|
||||||
@@ -17,8 +17,14 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch, require_torchvision, require_vision
|
from transformers.testing_utils import (
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
is_pt_tf_cross_test,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
require_torchvision,
|
||||||
|
require_vision,
|
||||||
|
)
|
||||||
|
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -29,6 +35,9 @@ if is_vision_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torchvision
|
@require_torchvision
|
||||||
@@ -110,3 +119,158 @@ class SamProcessorTest(unittest.TestCase):
|
|||||||
dummy_masks = [[1, 0], [0, 1]]
|
dummy_masks = [[1, 0], [0, 1]]
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
|
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_tf
|
||||||
|
class TFSamProcessorTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
image_processor = SamImageProcessor()
|
||||||
|
processor = SamProcessor(image_processor)
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_image_processor(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def prepare_image_inputs(self):
|
||||||
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||||
|
|
||||||
|
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
|
def test_save_load_pretrained_additional_features(self):
|
||||||
|
processor = SamProcessor(image_processor=self.get_image_processor())
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
|
||||||
|
|
||||||
|
processor = SamProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)
|
||||||
|
|
||||||
|
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
|
||||||
|
self.assertIsInstance(processor.image_processor, SamImageProcessor)
|
||||||
|
|
||||||
|
def test_image_processor(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
|
||||||
|
processor = SamProcessor(image_processor=image_processor)
|
||||||
|
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
input_feat_extract = image_processor(image_input, return_tensors="np")
|
||||||
|
input_processor = processor(images=image_input, return_tensors="np")
|
||||||
|
|
||||||
|
input_feat_extract.pop("original_sizes") # pop original_sizes as it is popped in the processor
|
||||||
|
input_feat_extract.pop("reshaped_input_sizes") # pop reshaped_input_sizes as it is popped in the processor
|
||||||
|
|
||||||
|
for key in input_feat_extract.keys():
|
||||||
|
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_post_process_masks(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
|
||||||
|
processor = SamProcessor(image_processor=image_processor)
|
||||||
|
dummy_masks = [tf.ones((1, 3, 5, 5))]
|
||||||
|
|
||||||
|
original_sizes = [[1764, 2646]]
|
||||||
|
|
||||||
|
reshaped_input_size = [[683, 1024]]
|
||||||
|
masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf")
|
||||||
|
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
|
||||||
|
|
||||||
|
masks = processor.post_process_masks(
|
||||||
|
dummy_masks,
|
||||||
|
tf.convert_to_tensor(original_sizes),
|
||||||
|
tf.convert_to_tensor(reshaped_input_size),
|
||||||
|
return_tensors="tf",
|
||||||
|
)
|
||||||
|
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
|
||||||
|
|
||||||
|
# should also work with np
|
||||||
|
dummy_masks = [np.ones((1, 3, 5, 5))]
|
||||||
|
masks = processor.post_process_masks(
|
||||||
|
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))
|
||||||
|
|
||||||
|
dummy_masks = [[1, 0], [0, 1]]
|
||||||
|
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||||
|
masks = processor.post_process_masks(
|
||||||
|
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_torchvision
|
||||||
|
class SamProcessorEquivalenceTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
image_processor = SamImageProcessor()
|
||||||
|
processor = SamProcessor(image_processor)
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_image_processor(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def prepare_image_inputs(self):
|
||||||
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||||
|
|
||||||
|
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_post_process_masks_equivalence(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
|
||||||
|
processor = SamProcessor(image_processor=image_processor)
|
||||||
|
dummy_masks = np.random.randint(0, 2, size=(1, 3, 5, 5)).astype(np.float32)
|
||||||
|
tf_dummy_masks = [tf.convert_to_tensor(dummy_masks)]
|
||||||
|
pt_dummy_masks = [torch.tensor(dummy_masks)]
|
||||||
|
|
||||||
|
original_sizes = [[1764, 2646]]
|
||||||
|
|
||||||
|
reshaped_input_size = [[683, 1024]]
|
||||||
|
tf_masks = processor.post_process_masks(
|
||||||
|
tf_dummy_masks, original_sizes, reshaped_input_size, return_tensors="tf"
|
||||||
|
)
|
||||||
|
pt_masks = processor.post_process_masks(
|
||||||
|
pt_dummy_masks, original_sizes, reshaped_input_size, return_tensors="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_image_processor_equivalence(self):
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
|
||||||
|
processor = SamProcessor(image_processor=image_processor)
|
||||||
|
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
pt_input_feat_extract = image_processor(image_input, return_tensors="pt")["pixel_values"].numpy()
|
||||||
|
pt_input_processor = processor(images=image_input, return_tensors="pt")["pixel_values"].numpy()
|
||||||
|
|
||||||
|
tf_input_feat_extract = image_processor(image_input, return_tensors="tf")["pixel_values"].numpy()
|
||||||
|
tf_input_processor = processor(images=image_input, return_tensors="tf")["pixel_values"].numpy()
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(pt_input_feat_extract, pt_input_processor))
|
||||||
|
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_feat_extract))
|
||||||
|
self.assertTrue(np.allclose(pt_input_feat_extract, tf_input_processor))
|
||||||
|
|||||||
Reference in New Issue
Block a user