TF port of the Segment Anything Model (SAM) (#22970)

* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matt
2023-05-19 14:14:13 +01:00
committed by GitHub
parent 8aa8513f71
commit 1c460a5273
14 changed files with 2940 additions and 44 deletions

View File

@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
class SamPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
values.
"""
def __init__(self, config, downsample_rate=None) -> None:
def __init__(self, config, downsample_rate=None):
super().__init__()
self.hidden_size = config.hidden_size
@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
A transformer block with four layers:
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
the embeddings of the mask inputs
multimask_output (bool):
Whether to return multiple masks or a single mask.
output_attentions (bool, **optional**):
output_attentions (bool, *optional*):
Whether or not to return the attentions tensors of all attention layers.
"""
batch_size, num_channels, height, width = image_embeddings.shape
@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
points (`torch.Tensor`, **optionnal**):
points (`torch.Tensor`, *optional*):
point coordinates and labels to embed.
boxes (`torch.Tensor`, **optionnal**):
boxes (`torch.Tensor`, *optional*):
boxes to embed
masks (`torch.Tensor`, **optionnal**):
masks (`torch.Tensor`, *optional*):
masks to embed
"""
sparse_embeddings = None
@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
class SamVisionAttention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
input_size = (
(config.image_size // config.patch_size, config.image_size // config.patch_size)
@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size) -> None:
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
class SamModel(SamPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
def __init__(self, config) -> None:
def __init__(self, config):
super().__init__(config)
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
vision_attentions = None
mask_decoder_attentions = None
vision_hidden_states = None
if pixel_values is not None:
@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
"The batch size of the image embeddings and the input points must be the same. ",
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
" if you want to pass multiple points for the same image, make sure that you passed ",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
)
sparse_embeddings, dense_embeddings = self.prompt_encoder(