Add PerSAM [bis] (#23659)
* Add PerSAM args * Make attn_sim optional * Rename to attention_similarity * Add docstrigns * Improve docstrings
This commit is contained in:
@@ -224,7 +224,7 @@ class SamAttention(nn.Module):
|
|||||||
hidden_states = hidden_states.transpose(1, 2)
|
hidden_states = hidden_states.transpose(1, 2)
|
||||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||||
|
|
||||||
def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
|
def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
|
||||||
# Input projections
|
# Input projections
|
||||||
query = self.q_proj(query)
|
query = self.q_proj(query)
|
||||||
key = self.k_proj(key)
|
key = self.k_proj(key)
|
||||||
@@ -242,6 +242,10 @@ class SamAttention(nn.Module):
|
|||||||
attn = attn / math.sqrt(c_per_head)
|
attn = attn / math.sqrt(c_per_head)
|
||||||
attn = torch.softmax(attn, dim=-1)
|
attn = torch.softmax(attn, dim=-1)
|
||||||
|
|
||||||
|
if attention_similarity is not None:
|
||||||
|
attn = attn + attention_similarity
|
||||||
|
attn = torch.softmax(attn, dim=-1)
|
||||||
|
|
||||||
# Get output
|
# Get output
|
||||||
out = attn @ value
|
out = attn @ value
|
||||||
out = self._recombine_heads(out, point_batch_size)
|
out = self._recombine_heads(out, point_batch_size)
|
||||||
@@ -290,6 +294,7 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||||||
keys: Tensor,
|
keys: Tensor,
|
||||||
query_point_embedding: Tensor,
|
query_point_embedding: Tensor,
|
||||||
key_point_embedding: Tensor,
|
key_point_embedding: Tensor,
|
||||||
|
attention_similarity: Tensor,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
):
|
):
|
||||||
# Self attention block
|
# Self attention block
|
||||||
@@ -305,7 +310,9 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
key = keys + key_point_embedding
|
key = keys + key_point_embedding
|
||||||
|
|
||||||
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
|
attn_out = self.cross_attn_token_to_image(
|
||||||
|
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
||||||
|
)
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
|
|
||||||
queries = self.layer_norm2(queries)
|
queries = self.layer_norm2(queries)
|
||||||
@@ -353,6 +360,8 @@ class SamTwoWayTransformer(nn.Module):
|
|||||||
point_embeddings: Tensor,
|
point_embeddings: Tensor,
|
||||||
image_embeddings: Tensor,
|
image_embeddings: Tensor,
|
||||||
image_positional_embeddings: Tensor,
|
image_positional_embeddings: Tensor,
|
||||||
|
attention_similarity: Tensor,
|
||||||
|
target_embedding=None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
@@ -377,11 +386,15 @@ class SamTwoWayTransformer(nn.Module):
|
|||||||
|
|
||||||
# Apply transformer blocks and final layernorm
|
# Apply transformer blocks and final layernorm
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
|
if target_embedding is not None:
|
||||||
|
queries += target_embedding
|
||||||
|
|
||||||
queries, keys, attention_outputs = layer(
|
queries, keys, attention_outputs = layer(
|
||||||
queries=queries,
|
queries=queries,
|
||||||
keys=keys,
|
keys=keys,
|
||||||
query_point_embedding=point_embeddings,
|
query_point_embedding=point_embeddings,
|
||||||
key_point_embedding=image_positional_embeddings,
|
key_point_embedding=image_positional_embeddings,
|
||||||
|
attention_similarity=attention_similarity,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -460,6 +473,8 @@ class SamMaskDecoder(nn.Module):
|
|||||||
dense_prompt_embeddings: torch.Tensor,
|
dense_prompt_embeddings: torch.Tensor,
|
||||||
multimask_output: bool,
|
multimask_output: bool,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
|
attention_similarity: torch.Tensor = None,
|
||||||
|
target_embedding: torch.Tensor = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Predict masks given image and prompt embeddings.
|
Predict masks given image and prompt embeddings.
|
||||||
@@ -500,6 +515,8 @@ class SamMaskDecoder(nn.Module):
|
|||||||
point_embeddings=point_embeddings,
|
point_embeddings=point_embeddings,
|
||||||
image_embeddings=image_embeddings,
|
image_embeddings=image_embeddings,
|
||||||
image_positional_embeddings=image_positional_embeddings,
|
image_positional_embeddings=image_positional_embeddings,
|
||||||
|
attention_similarity=attention_similarity,
|
||||||
|
target_embedding=target_embedding,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
iou_token_out = point_embedding[:, :, 0, :]
|
iou_token_out = point_embedding[:, :, 0, :]
|
||||||
@@ -576,8 +593,12 @@ class SamMaskEmbedding(nn.Module):
|
|||||||
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
|
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
|
||||||
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
|
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
|
||||||
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
|
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
|
||||||
self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
|
self.layer_norm1 = SamLayerNorm(
|
||||||
self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
|
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
|
||||||
|
)
|
||||||
|
self.layer_norm2 = SamLayerNorm(
|
||||||
|
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, masks):
|
def forward(self, masks):
|
||||||
hidden_states = self.conv1(masks)
|
hidden_states = self.conv1(masks)
|
||||||
@@ -1146,6 +1167,12 @@ SAM_INPUTS_DOCSTRING = r"""
|
|||||||
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
|
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
|
||||||
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
|
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
|
||||||
"best" mask, by specifying `multimask_output=False`.
|
"best" mask, by specifying `multimask_output=False`.
|
||||||
|
attention_similarity (`torch.FloatTensor`, *optional*):
|
||||||
|
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
|
||||||
|
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
|
||||||
|
target_embedding (`torch.FloatTensor`, *optional*):
|
||||||
|
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
|
||||||
|
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@@ -1265,6 +1292,8 @@ class SamModel(SamPreTrainedModel):
|
|||||||
input_masks: Optional[torch.LongTensor] = None,
|
input_masks: Optional[torch.LongTensor] = None,
|
||||||
image_embeddings: Optional[torch.FloatTensor] = None,
|
image_embeddings: Optional[torch.FloatTensor] = None,
|
||||||
multimask_output: bool = True,
|
multimask_output: bool = True,
|
||||||
|
attention_similarity: Optional[torch.FloatTensor] = None,
|
||||||
|
target_embedding: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -1374,6 +1403,8 @@ class SamModel(SamPreTrainedModel):
|
|||||||
sparse_prompt_embeddings=sparse_embeddings,
|
sparse_prompt_embeddings=sparse_embeddings,
|
||||||
dense_prompt_embeddings=dense_embeddings,
|
dense_prompt_embeddings=dense_embeddings,
|
||||||
multimask_output=multimask_output,
|
multimask_output=multimask_output,
|
||||||
|
attention_similarity=attention_similarity,
|
||||||
|
target_embedding=target_embedding,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user