TF port of the Segment Anything Model (SAM) (#22970)
* First commit * Add auto-translation with GPT-4 * make fixup * Add a functional layernorm for TF * Add all the auxiliary imports etc. * Add the extra processor and tests * rebase to main * Add all the needed fixes to the GPT code * make fixup * Make convolutions channels-last so they run on CPU * make fixup * Fix final issues * Fix other models affected by test change * Clarify comment on the sparse_prompt_embeddings check * Refactor functional_layernorm, use shape_list in place of .shape in some places * Remove deprecated torch-alike code * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/sam/test_modeling_tf_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Refactor processor with common methods and separated private methods * make fixup * Quietly delete the file that didn't do anything (sorry Sylvain) * Refactor the processor tests into one file * make fixup * Clean up some unnecessary indirection * Fix TF mask postprocessing * Add more processor equivalence tests * Refactor generate_crop_boxes to use framework-neutral np code * Make the serving output correctly conditional * Fix error message line length * Use dict keys rather than indices internally in both TF and PT SAM call/forward * Return dicts internally in the call/forward methods * Revert changes to common tests and just override check_pt_tf_outputs * Revert changes to other model tests * Clarify comments for functional layernorm * Add missing transpose from PT code * Removed unused copied from in PT code * Remove overrides for tests that don't exist in TF * Fix transpose and update tests for PT and TF to check pred_masks * Add training flag * Update tests to use TF checkpoints * Update index.mdx * Add missing cross-test decorator * Remove optional extra asterisks * Revert return_dict changes in PT code * Update src/transformers/models/sam/modeling_tf_sam.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove None return annotations on init methods * Update tests/models/sam/test_processor_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix input_boxes shapes * make fixup --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -111,7 +111,6 @@ class SamImageSegmentationOutput(ModelOutput):
|
||||
mask_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Copied from src.models.modeling_vit_mae.ViTMAEPatchEmbeddings with ViTMAEPatchEmbeddings->SamVisionEmbeddings,x->embeddings
|
||||
class SamPatchEmbeddings(nn.Module):
|
||||
"""
|
||||
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
||||
@@ -198,7 +197,7 @@ class SamAttention(nn.Module):
|
||||
values.
|
||||
"""
|
||||
|
||||
def __init__(self, config, downsample_rate=None) -> None:
|
||||
def __init__(self, config, downsample_rate=None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
@@ -252,7 +251,7 @@ class SamAttention(nn.Module):
|
||||
|
||||
|
||||
class SamTwoWayAttentionBlock(nn.Module):
|
||||
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False) -> None:
|
||||
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
|
||||
"""
|
||||
A transformer block with four layers:
|
||||
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
|
||||
@@ -476,7 +475,7 @@ class SamMaskDecoder(nn.Module):
|
||||
the embeddings of the mask inputs
|
||||
multimask_output (bool):
|
||||
Whether to return multiple masks or a single mask.
|
||||
output_attentions (bool, **optional**):
|
||||
output_attentions (bool, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
"""
|
||||
batch_size, num_channels, height, width = image_embeddings.shape
|
||||
@@ -668,11 +667,11 @@ class SamPromptEncoder(nn.Module):
|
||||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
points (`torch.Tensor`, **optionnal**):
|
||||
points (`torch.Tensor`, *optional*):
|
||||
point coordinates and labels to embed.
|
||||
boxes (`torch.Tensor`, **optionnal**):
|
||||
boxes (`torch.Tensor`, *optional*):
|
||||
boxes to embed
|
||||
masks (`torch.Tensor`, **optionnal**):
|
||||
masks (`torch.Tensor`, *optional*):
|
||||
masks to embed
|
||||
"""
|
||||
sparse_embeddings = None
|
||||
@@ -707,7 +706,7 @@ class SamPromptEncoder(nn.Module):
|
||||
class SamVisionAttention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(self, config, window_size) -> None:
|
||||
def __init__(self, config, window_size):
|
||||
super().__init__()
|
||||
input_size = (
|
||||
(config.image_size // config.patch_size, config.image_size // config.patch_size)
|
||||
@@ -845,7 +844,7 @@ class SamVisionAttention(nn.Module):
|
||||
|
||||
|
||||
class SamVisionLayer(nn.Module):
|
||||
def __init__(self, config, window_size) -> None:
|
||||
def __init__(self, config, window_size):
|
||||
super().__init__()
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attn = SamVisionAttention(config, window_size)
|
||||
@@ -1166,7 +1165,7 @@ SAM_INPUTS_DOCSTRING = r"""
|
||||
class SamModel(SamPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
|
||||
|
||||
@@ -1334,7 +1333,6 @@ class SamModel(SamPreTrainedModel):
|
||||
image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
|
||||
|
||||
vision_attentions = None
|
||||
mask_decoder_attentions = None
|
||||
vision_hidden_states = None
|
||||
|
||||
if pixel_values is not None:
|
||||
@@ -1359,7 +1357,8 @@ class SamModel(SamPreTrainedModel):
|
||||
"The batch size of the image embeddings and the input points must be the same. ",
|
||||
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
|
||||
" if you want to pass multiple points for the same image, make sure that you passed ",
|
||||
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
||||
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
|
||||
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
|
||||
)
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
||||
|
||||
Reference in New Issue
Block a user