Update SAM/SAM HQ attention implementation + fix Cuda sync issues (#39386)

* update attention implementation and improve inference speed

* modular sam_hq + fix integration tests on A10

* fixup

* fix after review

* softmax in correct place

* return attn_weights in sam/sam_hq
This commit is contained in:
Yoni Gozlan
2025-07-18 18:46:27 -04:00
committed by GitHub
parent 541bed22d6
commit 433d2a23d7
4 changed files with 149 additions and 187 deletions

View File

@@ -16,7 +16,7 @@
import collections import collections
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Callable, Optional, Union
import numpy as np import numpy as np
import torch import torch
@@ -28,7 +28,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
@@ -178,6 +178,28 @@ class SamLayerNorm(nn.Module):
return x return x
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (non-fused) scaled dot-product attention.

    Scores are `query @ key^T * scaling`. When `attention_mask` is given it is added to the scores before the
    softmax. The softmax is evaluated in float32 and cast back to `query.dtype`. Dropout on the attention weights is
    only active while `module.training` is true.

    Args:
        module: Only `module.training` is read, to decide whether dropout is applied.
        query, key, value: Indexed as 4-D tensors (dims 2 and 3 of `key` are swapped for the score product),
            presumably `(batch, heads, tokens, head_dim)` -- confirm against the caller.
        attention_mask: Optional additive term; it has to broadcast against the score tensor.
        scaling: Multiplier applied to the raw scores.
        dropout: Dropout probability for the attention weights.
        **kwargs: Accepted and ignored, so this function can be called with the same arguments as the other
            attention backends.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` is `attn_weights @ value` with dims 1 and 2 swapped and made
        contiguous; `attn_weights` are the post-softmax, post-dropout weights.
    """
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax in float32 for numerical stability, then back to the input dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a plain reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class SamAttention(nn.Module): class SamAttention(nn.Module):
""" """
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):
def __init__(self, config, downsample_rate=None): def __init__(self, config, downsample_rate=None):
super().__init__() super().__init__()
self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -194,12 +217,15 @@ class SamAttention(nn.Module):
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0: if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.") raise ValueError("num_attention_heads must divide hidden_size.")
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads c_per_head = channel // num_attention_heads
@@ -207,12 +233,16 @@ class SamAttention(nn.Module):
return hidden_states.transpose(1, 2) return hidden_states.transpose(1, 2)
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_heads, n_tokens, c_per_head = hidden_states.shape batch, n_tokens, n_heads, c_per_head = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2)
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward( def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Optional[Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Tensor: ) -> Tensor:
# Input projections # Input projections
query = self.q_proj(query) query = self.q_proj(query)
@@ -226,64 +256,26 @@ class SamAttention(nn.Module):
value = self._separate_heads(value, self.num_attention_heads) value = self._separate_heads(value, self.num_attention_heads)
# SamAttention # SamAttention
_, _, _, c_per_head = query.shape attention_interface: Callable = eager_attention_forward
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens if self.config._attn_implementation != "eager":
attn = attn / (c_per_head**0.5) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn = torch.softmax(attn, dim=-1)
if attention_similarity is not None: attn_output, attn_weights = attention_interface(
attn = attn + attention_similarity self,
attn = torch.softmax(attn, dim=-1) query,
key,
value,
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=self.scaling,
is_causal=self.is_causal,
**kwargs,
)
# Get output attn_output = self._recombine_heads(attn_output, point_batch_size)
out = attn @ value attn_output = self.out_proj(attn_output)
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out return attn_output, attn_weights
class SamSdpaAttention(SamAttention):
    """
    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values. Using SDPA instead of the default attention.

    Construction is inherited unchanged from `SamAttention` (same `config` / `downsample_rate` arguments); only
    `forward` is overridden.
    """

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
    ) -> Tensor:
        """
        Project the inputs, run `F.scaled_dot_product_attention` per head, and project the result back.

        Args:
            query, key, value: Inputs to the q/k/v projections. Dim 1 of the projected query is read as the point
                batch size.
            attention_similarity: Optional tensor passed to SDPA as `attn_mask` after a heads axis is inserted at
                dim 1 and expanded to `num_attention_heads`.

        Returns:
            The attention output after `out_proj`.
        """
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # Remembered so the head split can be undone after attention.
        point_batch_size = query.shape[1]

        # Separate into heads
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)

        # Scaled dot product attention
        attn_mask = None
        if attention_similarity is not None:
            # Insert a heads axis and share the same values across every head.
            attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

        # Get output
        out = self._recombine_heads(out, point_batch_size)
        out = self.out_proj(out)
        return out
# Lookup from an attention implementation name (the value of `config._attn_implementation`)
# to the attention class that gets instantiated for it.
SAM_ATTENTION_CLASSES = {
    "eager": SamAttention,
    "sdpa": SamSdpaAttention,
}
class SamTwoWayAttentionBlock(nn.Module): class SamTwoWayAttentionBlock(nn.Module):
@@ -306,21 +298,17 @@ class SamTwoWayAttentionBlock(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps self.layer_norm_eps = config.layer_norm_eps
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) self.self_attn = SamAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation]( self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
config, downsample_rate=attention_downsample_rate
)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamMLPBlock(config) self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation]( self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
config, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe self.skip_first_layer_pe = skip_first_layer_pe
def forward( def forward(
@@ -330,13 +318,14 @@ class SamTwoWayAttentionBlock(nn.Module):
query_point_embedding: Tensor, query_point_embedding: Tensor,
key_point_embedding: Tensor, key_point_embedding: Tensor,
attention_similarity: Tensor, attention_similarity: Tensor,
**kwargs: Unpack[TransformersKwargs],
): ):
# Self attention block # Self attention block
if self.skip_first_layer_pe: if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries) queries, _ = self.self_attn(query=queries, key=queries, value=queries)
else: else:
query = queries + query_point_embedding query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries) attn_out, _ = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out queries = queries + attn_out
queries = self.layer_norm1(queries) queries = self.layer_norm1(queries)
@@ -344,7 +333,7 @@ class SamTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding query = queries + query_point_embedding
key = keys + key_point_embedding key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image( attn_out, _ = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity query=query, key=key, value=keys, attention_similarity=attention_similarity
) )
queries = queries + attn_out queries = queries + attn_out
@@ -360,7 +349,7 @@ class SamTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding query = queries + query_point_embedding
key = keys + key_point_embedding key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out keys = keys + attn_out
keys = self.layer_norm4(keys) keys = self.layer_norm4(keys)
@@ -378,7 +367,7 @@ class SamTwoWayTransformer(nn.Module):
for i in range(self.num_hidden_layers): for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config) self.final_attn_token_to_image = SamAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward( def forward(
@@ -388,6 +377,7 @@ class SamTwoWayTransformer(nn.Module):
image_positional_embeddings: Tensor, image_positional_embeddings: Tensor,
attention_similarity: Tensor, attention_similarity: Tensor,
target_embedding=None, target_embedding=None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutput]: ) -> Union[tuple, BaseModelOutput]:
if image_embeddings is None: if image_embeddings is None:
raise ValueError("You have to specify an image_embedding") raise ValueError("You have to specify an image_embedding")
@@ -410,12 +400,13 @@ class SamTwoWayTransformer(nn.Module):
query_point_embedding=point_embeddings, query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings, key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity, attention_similarity=attention_similarity,
**kwargs,
) )
# Apply the final attention layer from the points to the image # Apply the final attention layer from the points to the image
query = queries + point_embeddings query = queries + point_embeddings
key = keys + image_positional_embeddings key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out queries = queries + attn_out
queries = self.layer_norm_final_attn(queries) queries = self.layer_norm_final_attn(queries)
@@ -501,12 +492,12 @@ class SamMaskDecoder(nn.Module):
Whether to return multiple masks or a single mask. Whether to return multiple masks or a single mask.
""" """
batch_size, num_channels, height, width = image_embeddings.shape batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1] point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
# Concatenate output tokens # Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if sparse_prompt_embeddings.sum().item() != 0: if sparse_prompt_embeddings is not None:
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
else: else:
tokens = output_tokens tokens = output_tokens
@@ -611,7 +602,7 @@ class SamMaskEmbedding(nn.Module):
class SamPromptEncoder(nn.Module): class SamPromptEncoder(nn.Module):
def __init__(self, config: SamPromptEncoderConfig): def __init__(self, config: SamConfig):
super().__init__() super().__init__()
self.shared_embedding = SamPositionalEmbedding(config.vision_config) self.shared_embedding = SamPositionalEmbedding(config.vision_config)
config = config.prompt_encoder_config config = config.prompt_encoder_config
@@ -645,11 +636,7 @@ class SamPromptEncoder(nn.Module):
# This is required for the ONNX export. The dtype, device need to be explicitly # This is required for the ONNX export. The dtype, device need to be explicitly
# specified as otherwise torch.onnx.export interprets as double # specified as otherwise torch.onnx.export interprets as double
point_embedding = torch.where( point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
)
point_embedding = torch.where( point_embedding = torch.where(
(labels == 0)[:, :, :, None], (labels == 0)[:, :, :, None],
@@ -696,9 +683,8 @@ class SamPromptEncoder(nn.Module):
""" """
sparse_embeddings = None sparse_embeddings = None
batch_size = 1 batch_size = 1
target_device = self.shared_embedding.positional_embedding.device
if input_points is not None: if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2] batch_size = input_points.shape[0]
if input_labels is None: if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.") raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@@ -717,9 +703,6 @@ class SamPromptEncoder(nn.Module):
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
) )
if sparse_embeddings is None:
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
return sparse_embeddings, dense_embeddings return sparse_embeddings, dense_embeddings
@@ -1184,10 +1167,6 @@ class SamModel(SamPreTrainedModel):
Args: Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
""" """
vision_output = self.vision_encoder( vision_output = self.vision_encoder(
pixel_values, pixel_values,

View File

@@ -21,7 +21,7 @@
# limitations under the License. # limitations under the License.
import collections import collections
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Callable, Optional, Union
import numpy as np import numpy as np
import torch import torch
@@ -34,7 +34,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, logging from ...utils import auto_docstring, logging
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig
@@ -601,6 +601,28 @@ class SamHQLayerNorm(nn.Module):
return x return x
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (non-fused) scaled dot-product attention.

    Computes `softmax(query @ key^T * scaling + attention_mask) @ value`, where the mask term is skipped when
    `attention_mask` is `None`. The softmax runs in float32 and the result is cast back to `query.dtype`. Dropout on
    the attention weights is only active while `module.training` is true.

    Args:
        module: Only `module.training` is read, to decide whether dropout is applied.
        query, key, value: Indexed as 4-D tensors (dims 2 and 3 of `key` are swapped for the score product),
            presumably `(batch, heads, tokens, head_dim)` -- confirm against the caller.
        attention_mask: Optional additive term; it has to broadcast against the score tensor.
        scaling: Multiplier applied to the raw scores.
        dropout: Dropout probability for the attention weights.
        **kwargs: Accepted and ignored, so this function can be called with the same arguments as the other
            attention backends.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` is `attn_weights @ value` with dims 1 and 2 swapped and made
        contiguous; `attn_weights` are the post-softmax, post-dropout weights.
    """
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax in float32 for numerical stability, then back to the input dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a plain reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class SamHQAttention(nn.Module): class SamHQAttention(nn.Module):
""" """
SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -609,6 +631,7 @@ class SamHQAttention(nn.Module):
def __init__(self, config, downsample_rate=None): def __init__(self, config, downsample_rate=None):
super().__init__() super().__init__()
self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -617,12 +640,15 @@ class SamHQAttention(nn.Module):
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0: if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.") raise ValueError("num_attention_heads must divide hidden_size.")
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads c_per_head = channel // num_attention_heads
@@ -630,12 +656,16 @@ class SamHQAttention(nn.Module):
return hidden_states.transpose(1, 2) return hidden_states.transpose(1, 2)
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_heads, n_tokens, c_per_head = hidden_states.shape batch, n_tokens, n_heads, c_per_head = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2)
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward( def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Optional[Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Tensor: ) -> Tensor:
# Input projections # Input projections
query = self.q_proj(query) query = self.q_proj(query)
@@ -649,64 +679,26 @@ class SamHQAttention(nn.Module):
value = self._separate_heads(value, self.num_attention_heads) value = self._separate_heads(value, self.num_attention_heads)
# SamHQAttention # SamHQAttention
_, _, _, c_per_head = query.shape attention_interface: Callable = eager_attention_forward
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens if self.config._attn_implementation != "eager":
attn = attn / (c_per_head**0.5) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn = torch.softmax(attn, dim=-1)
if attention_similarity is not None: attn_output, attn_weights = attention_interface(
attn = attn + attention_similarity self,
attn = torch.softmax(attn, dim=-1) query,
key,
value,
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=self.scaling,
is_causal=self.is_causal,
**kwargs,
)
# Get output attn_output = self._recombine_heads(attn_output, point_batch_size)
out = attn @ value attn_output = self.out_proj(attn_output)
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out return attn_output, attn_weights
class SamHQSdpaAttention(SamHQAttention):
    """
    SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values. Using SDPA instead of the default attention.

    Construction is inherited unchanged from `SamHQAttention` (same `config` / `downsample_rate` arguments); only
    `forward` is overridden.
    """

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
    ) -> Tensor:
        """
        Project the inputs, run `F.scaled_dot_product_attention` per head, and project the result back.

        Args:
            query, key, value: Inputs to the q/k/v projections. Dim 1 of the projected query is read as the point
                batch size.
            attention_similarity: Optional tensor passed to SDPA as `attn_mask` after a heads axis is inserted at
                dim 1 and expanded to `num_attention_heads`.

        Returns:
            The attention output after `out_proj`.
        """
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # Remembered so the head split can be undone after attention.
        point_batch_size = query.shape[1]

        # Separate into heads
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)

        # Scaled dot product attention
        attn_mask = None
        if attention_similarity is not None:
            # Insert a heads axis and share the same values across every head.
            attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

        # Get output
        out = self._recombine_heads(out, point_batch_size)
        out = self.out_proj(out)
        return out
# Lookup from an attention implementation name (the value of `config._attn_implementation`)
# to the attention class that gets instantiated for it.
SAM_HQ_ATTENTION_CLASSES = {
    "eager": SamHQAttention,
    "sdpa": SamHQSdpaAttention,
}
class SamHQTwoWayAttentionBlock(nn.Module): class SamHQTwoWayAttentionBlock(nn.Module):
@@ -729,21 +721,17 @@ class SamHQTwoWayAttentionBlock(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps self.layer_norm_eps = config.layer_norm_eps
self.self_attn = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) self.self_attn = SamHQAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation]( self.cross_attn_token_to_image = SamHQAttention(config, downsample_rate=attention_downsample_rate)
config, downsample_rate=attention_downsample_rate
)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamHQMLPBlock(config) self.mlp = SamHQMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation]( self.cross_attn_image_to_token = SamHQAttention(config, downsample_rate=attention_downsample_rate)
config, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe self.skip_first_layer_pe = skip_first_layer_pe
def forward( def forward(
@@ -753,13 +741,14 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query_point_embedding: Tensor, query_point_embedding: Tensor,
key_point_embedding: Tensor, key_point_embedding: Tensor,
attention_similarity: Tensor, attention_similarity: Tensor,
**kwargs: Unpack[TransformersKwargs],
): ):
# Self attention block # Self attention block
if self.skip_first_layer_pe: if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries) queries, _ = self.self_attn(query=queries, key=queries, value=queries)
else: else:
query = queries + query_point_embedding query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries) attn_out, _ = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out queries = queries + attn_out
queries = self.layer_norm1(queries) queries = self.layer_norm1(queries)
@@ -767,7 +756,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding query = queries + query_point_embedding
key = keys + key_point_embedding key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image( attn_out, _ = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity query=query, key=key, value=keys, attention_similarity=attention_similarity
) )
queries = queries + attn_out queries = queries + attn_out
@@ -783,7 +772,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding query = queries + query_point_embedding
key = keys + key_point_embedding key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out keys = keys + attn_out
keys = self.layer_norm4(keys) keys = self.layer_norm4(keys)
@@ -801,7 +790,7 @@ class SamHQTwoWayTransformer(nn.Module):
for i in range(self.num_hidden_layers): for i in range(self.num_hidden_layers):
self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
self.final_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config) self.final_attn_token_to_image = SamHQAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward( def forward(
@@ -811,6 +800,7 @@ class SamHQTwoWayTransformer(nn.Module):
image_positional_embeddings: Tensor, image_positional_embeddings: Tensor,
attention_similarity: Tensor, attention_similarity: Tensor,
target_embedding=None, target_embedding=None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutput]: ) -> Union[tuple, BaseModelOutput]:
if image_embeddings is None: if image_embeddings is None:
raise ValueError("You have to specify an image_embedding") raise ValueError("You have to specify an image_embedding")
@@ -833,12 +823,13 @@ class SamHQTwoWayTransformer(nn.Module):
query_point_embedding=point_embeddings, query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings, key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity, attention_similarity=attention_similarity,
**kwargs,
) )
# Apply the final attention layer from the points to the image # Apply the final attention layer from the points to the image
query = queries + point_embeddings query = queries + point_embeddings
key = keys + image_positional_embeddings key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out queries = queries + attn_out
queries = self.layer_norm_final_attn(queries) queries = self.layer_norm_final_attn(queries)
@@ -957,7 +948,7 @@ class SamHQMaskDecoder(nn.Module):
- (Optional) A tuple containing attention tensors if output_attentions is True. - (Optional) A tuple containing attention tensors if output_attentions is True.
""" """
batch_size, num_channels, height, width = image_embeddings.shape batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1] point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0 has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
@@ -980,7 +971,7 @@ class SamHQMaskDecoder(nn.Module):
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if torch.any(sparse_prompt_embeddings != 0): if sparse_prompt_embeddings is not None:
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
else: else:
tokens = output_tokens tokens = output_tokens
@@ -1147,7 +1138,7 @@ class SamHQMaskEmbedding(nn.Module):
class SamHQPromptEncoder(nn.Module): class SamHQPromptEncoder(nn.Module):
def __init__(self, config: SamHQPromptEncoderConfig): def __init__(self, config: SamHQConfig):
super().__init__() super().__init__()
self.shared_embedding = SamHQPositionalEmbedding(config.vision_config) self.shared_embedding = SamHQPositionalEmbedding(config.vision_config)
config = config.prompt_encoder_config config = config.prompt_encoder_config
@@ -1181,11 +1172,7 @@ class SamHQPromptEncoder(nn.Module):
# This is required for the ONNX export. The dtype, device need to be explicitly # This is required for the ONNX export. The dtype, device need to be explicitly
# specified as otherwise torch.onnx.export interprets as double # specified as otherwise torch.onnx.export interprets as double
point_embedding = torch.where( point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
)
point_embedding = torch.where( point_embedding = torch.where(
(labels == 0)[:, :, :, None], (labels == 0)[:, :, :, None],
@@ -1232,9 +1219,8 @@ class SamHQPromptEncoder(nn.Module):
""" """
sparse_embeddings = None sparse_embeddings = None
batch_size = 1 batch_size = 1
target_device = self.shared_embedding.positional_embedding.device
if input_points is not None: if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2] batch_size = input_points.shape[0]
if input_labels is None: if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.") raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@@ -1253,9 +1239,6 @@ class SamHQPromptEncoder(nn.Module):
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
) )
if sparse_embeddings is None:
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
return sparse_embeddings, dense_embeddings return sparse_embeddings, dense_embeddings
@@ -1517,8 +1500,8 @@ class SamHQModel(SamHQPreTrainedModel):
return SamHQImageSegmentationOutput( return SamHQImageSegmentationOutput(
iou_scores=mask_decoder_output[1], iou_scores=mask_decoder_output[1],
pred_masks=mask_decoder_output[0], pred_masks=mask_decoder_output[0],
vision_hidden_states=vision_outputs.hidden_states, vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None,
vision_attentions=vision_outputs.attentions, vision_attentions=vision_outputs.attentions if pixel_values is not None else None,
) )

View File

@@ -327,7 +327,7 @@ class SamHQMaskDecoder(nn.Module):
- (Optional) A tuple containing attention tensors if output_attentions is True. - (Optional) A tuple containing attention tensors if output_attentions is True.
""" """
batch_size, num_channels, height, width = image_embeddings.shape batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1] point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0 has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
@@ -350,7 +350,7 @@ class SamHQMaskDecoder(nn.Module):
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if torch.any(sparse_prompt_embeddings != 0): if sparse_prompt_embeddings is not None:
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
else: else:
tokens = output_tokens tokens = output_tokens
@@ -641,8 +641,8 @@ class SamHQModel(SamModel):
return SamHQImageSegmentationOutput( return SamHQImageSegmentationOutput(
iou_scores=mask_decoder_output[1], iou_scores=mask_decoder_output[1],
pred_masks=mask_decoder_output[0], pred_masks=mask_decoder_output[0],
vision_hidden_states=vision_outputs.hidden_states, vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None,
vision_attentions=vision_outputs.attentions, vision_attentions=vision_outputs.attentions if pixel_values is not None else None,
) )

View File

@@ -806,7 +806,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
expectations = Expectations( expectations = Expectations(
{ {
(None, None): [-13.1695, -14.6201, -14.8989], (None, None): [-13.1695, -14.6201, -14.8989],
("cuda", 8): [-13.1668, -14.6182, -14.8970], ("cuda", 8): [-7.6769, -9.6935, -9.8773],
} }
) )
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device) EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
@@ -831,9 +831,9 @@ class SamHQModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
scores = outputs.iou_scores.squeeze() scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3] masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4)) torch.testing.assert_close(scores[-1], torch.tensor(0.9700).to(torch_device), atol=2e-4, rtol=2e-4)
self.assertTrue( torch.testing.assert_close(
torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2) masks, torch.tensor([-9.2033, -8.5505, -7.1361]).to(torch_device), atol=3e-2, rtol=3e-2
) )
def test_inference_mask_generation_batched_points_batched_images(self): def test_inference_mask_generation_batched_points_batched_images(self):
@@ -895,7 +895,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
expectations = Expectations( expectations = Expectations(
{ {
(None, None): [-40.2445, -37.4300, -38.1577], (None, None): [-40.2445, -37.4300, -38.1577],
("cuda", 8): [-40.2351, -37.4334, -38.1526], ("cuda", 8): [-14.1195, -17.2663, -13.7805],
} }
) )
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device) EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)