Update SAM/SAM HQ attention implementation + fix Cuda sync issues (#39386)

* update attention implementation and improve inference speed

* modular sam_hq + fix integration tests on A10

* fixup

* fix after review

* softmax in correct place

* return attn_weights in sam/sam_hq
This commit is contained in:
Yoni Gozlan
2025-07-18 18:46:27 -04:00
committed by GitHub
parent 541bed22d6
commit 433d2a23d7
4 changed files with 149 additions and 187 deletions

View File

@@ -16,7 +16,7 @@
import collections
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union
import numpy as np
import torch
@@ -28,7 +28,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
@@ -178,6 +178,28 @@ class SamLayerNorm(nn.Module):
return x
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class SamAttention(nn.Module):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):
def __init__(self, config, downsample_rate=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -194,12 +217,15 @@ class SamAttention(nn.Module):
self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.")
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
@@ -207,12 +233,16 @@ class SamAttention(nn.Module):
return hidden_states.transpose(1, 2)
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2)
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Optional[Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Tensor:
# Input projections
query = self.q_proj(query)
@@ -226,64 +256,26 @@ class SamAttention(nn.Module):
value = self._separate_heads(value, self.num_attention_heads)
# SamAttention
_, _, _, c_per_head = query.shape
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
attn = attn / (c_per_head**0.5)
attn = torch.softmax(attn, dim=-1)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if attention_similarity is not None:
attn = attn + attention_similarity
attn = torch.softmax(attn, dim=-1)
attn_output, attn_weights = attention_interface(
self,
query,
key,
value,
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=self.scaling,
is_causal=self.is_causal,
**kwargs,
)
# Get output
out = attn @ value
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
attn_output = self._recombine_heads(attn_output, point_batch_size)
attn_output = self.out_proj(attn_output)
return out
class SamSdpaAttention(SamAttention):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values. Using SDPA instead of the default attention.
"""
def __init__(self, config, downsample_rate=None):
super().__init__(config, downsample_rate)
def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)
# Scaled dot product attention
attn_mask = None
if attention_similarity is not None:
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
# Get output
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
SAM_ATTENTION_CLASSES = {
"eager": SamAttention,
"sdpa": SamSdpaAttention,
}
return attn_output, attn_weights
class SamTwoWayAttentionBlock(nn.Module):
@@ -306,21 +298,17 @@ class SamTwoWayAttentionBlock(nn.Module):
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
self.self_attn = SamAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
@@ -330,13 +318,14 @@ class SamTwoWayAttentionBlock(nn.Module):
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
**kwargs: Unpack[TransformersKwargs],
):
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries)
queries, _ = self.self_attn(query=queries, key=queries, value=queries)
else:
query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries)
attn_out, _ = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out
queries = self.layer_norm1(queries)
@@ -344,7 +333,7 @@ class SamTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image(
attn_out, _ = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity
)
queries = queries + attn_out
@@ -360,7 +349,7 @@ class SamTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out
keys = self.layer_norm4(keys)
@@ -378,7 +367,7 @@ class SamTwoWayTransformer(nn.Module):
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
self.final_attn_token_to_image = SamAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward(
@@ -388,6 +377,7 @@ class SamTwoWayTransformer(nn.Module):
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutput]:
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
@@ -410,12 +400,13 @@ class SamTwoWayTransformer(nn.Module):
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
**kwargs,
)
# Apply the final attenion layer from the points to the image
query = queries + point_embeddings
key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
@@ -501,12 +492,12 @@ class SamMaskDecoder(nn.Module):
Whether to return multiple masks or a single mask.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if sparse_prompt_embeddings.sum().item() != 0:
if sparse_prompt_embeddings is not None:
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
else:
tokens = output_tokens
@@ -611,7 +602,7 @@ class SamMaskEmbedding(nn.Module):
class SamPromptEncoder(nn.Module):
def __init__(self, config: SamPromptEncoderConfig):
def __init__(self, config: SamConfig):
super().__init__()
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
config = config.prompt_encoder_config
@@ -645,11 +636,7 @@ class SamPromptEncoder(nn.Module):
# This is required for the ONNX export. The dtype, device need to be explicitly
# specified as otherwise torch.onnx.export interprets as double
point_embedding = torch.where(
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
)
point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
point_embedding = torch.where(
(labels == 0)[:, :, :, None],
@@ -696,9 +683,8 @@ class SamPromptEncoder(nn.Module):
"""
sparse_embeddings = None
batch_size = 1
target_device = self.shared_embedding.positional_embedding.device
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
batch_size = input_points.shape[0]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@@ -717,9 +703,6 @@ class SamPromptEncoder(nn.Module):
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
if sparse_embeddings is None:
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
return sparse_embeddings, dense_embeddings
@@ -1184,10 +1167,6 @@ class SamModel(SamPreTrainedModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
"""
vision_output = self.vision_encoder(
pixel_values,

View File

@@ -21,7 +21,7 @@
# limitations under the License.
import collections
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union
import numpy as np
import torch
@@ -34,7 +34,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig
@@ -601,6 +601,28 @@ class SamHQLayerNorm(nn.Module):
return x
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class SamHQAttention(nn.Module):
"""
SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -609,6 +631,7 @@ class SamHQAttention(nn.Module):
def __init__(self, config, downsample_rate=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -617,12 +640,15 @@ class SamHQAttention(nn.Module):
self.num_attention_heads = config.num_attention_heads
if self.internal_dim % config.num_attention_heads != 0:
raise ValueError("num_attention_heads must divide hidden_size.")
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
self.is_causal = False
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
@@ -630,12 +656,16 @@ class SamHQAttention(nn.Module):
return hidden_states.transpose(1, 2)
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2)
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Optional[Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Tensor:
# Input projections
query = self.q_proj(query)
@@ -649,64 +679,26 @@ class SamHQAttention(nn.Module):
value = self._separate_heads(value, self.num_attention_heads)
# SamHQAttention
_, _, _, c_per_head = query.shape
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
attn = attn / (c_per_head**0.5)
attn = torch.softmax(attn, dim=-1)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if attention_similarity is not None:
attn = attn + attention_similarity
attn = torch.softmax(attn, dim=-1)
attn_output, attn_weights = attention_interface(
self,
query,
key,
value,
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=self.scaling,
is_causal=self.is_causal,
**kwargs,
)
# Get output
out = attn @ value
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
attn_output = self._recombine_heads(attn_output, point_batch_size)
attn_output = self.out_proj(attn_output)
return out
class SamHQSdpaAttention(SamHQAttention):
"""
SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values. Using SDPA instead of the default attention.
"""
def __init__(self, config, downsample_rate=None):
super().__init__(config, downsample_rate)
def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)
# Scaled dot product attention
attn_mask = None
if attention_similarity is not None:
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
# Get output
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
SAM_HQ_ATTENTION_CLASSES = {
"eager": SamHQAttention,
"sdpa": SamHQSdpaAttention,
}
return attn_output, attn_weights
class SamHQTwoWayAttentionBlock(nn.Module):
@@ -729,21 +721,17 @@ class SamHQTwoWayAttentionBlock(nn.Module):
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.self_attn = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
self.self_attn = SamHQAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.cross_attn_token_to_image = SamHQAttention(config, downsample_rate=attention_downsample_rate)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamHQMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.cross_attn_image_to_token = SamHQAttention(config, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
@@ -753,13 +741,14 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
**kwargs: Unpack[TransformersKwargs],
):
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(query=queries, key=queries, value=queries)
queries, _ = self.self_attn(query=queries, key=queries, value=queries)
else:
query = queries + query_point_embedding
attn_out = self.self_attn(query=query, key=query, value=queries)
attn_out, _ = self.self_attn(query=query, key=query, value=queries)
queries = queries + attn_out
queries = self.layer_norm1(queries)
@@ -767,7 +756,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_token_to_image(
attn_out, _ = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity
)
queries = queries + attn_out
@@ -783,7 +772,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
query = queries + query_point_embedding
key = keys + key_point_embedding
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
keys = keys + attn_out
keys = self.layer_norm4(keys)
@@ -801,7 +790,7 @@ class SamHQTwoWayTransformer(nn.Module):
for i in range(self.num_hidden_layers):
self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
self.final_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config)
self.final_attn_token_to_image = SamHQAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward(
@@ -811,6 +800,7 @@ class SamHQTwoWayTransformer(nn.Module):
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutput]:
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
@@ -833,12 +823,13 @@ class SamHQTwoWayTransformer(nn.Module):
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
**kwargs,
)
# Apply the final attenion layer from the points to the image
query = queries + point_embeddings
key = keys + image_positional_embeddings
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
queries = queries + attn_out
queries = self.layer_norm_final_attn(queries)
@@ -957,7 +948,7 @@ class SamHQMaskDecoder(nn.Module):
- (Optional) A tuple containing attention tensors if output_attentions is True.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
@@ -980,7 +971,7 @@ class SamHQMaskDecoder(nn.Module):
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if torch.any(sparse_prompt_embeddings != 0):
if sparse_prompt_embeddings is not None:
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
else:
tokens = output_tokens
@@ -1147,7 +1138,7 @@ class SamHQMaskEmbedding(nn.Module):
class SamHQPromptEncoder(nn.Module):
def __init__(self, config: SamHQPromptEncoderConfig):
def __init__(self, config: SamHQConfig):
super().__init__()
self.shared_embedding = SamHQPositionalEmbedding(config.vision_config)
config = config.prompt_encoder_config
@@ -1181,11 +1172,7 @@ class SamHQPromptEncoder(nn.Module):
# This is required for the ONNX export. The dtype, device need to be explicitly
# specified as otherwise torch.onnx.export interprets as double
point_embedding = torch.where(
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
)
point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
point_embedding = torch.where(
(labels == 0)[:, :, :, None],
@@ -1232,9 +1219,8 @@ class SamHQPromptEncoder(nn.Module):
"""
sparse_embeddings = None
batch_size = 1
target_device = self.shared_embedding.positional_embedding.device
if input_points is not None:
batch_size, point_batch_size = input_points.shape[:2]
batch_size = input_points.shape[0]
if input_labels is None:
raise ValueError("If points are provided, labels must also be provided.")
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@@ -1253,9 +1239,6 @@ class SamHQPromptEncoder(nn.Module):
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
if sparse_embeddings is None:
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
return sparse_embeddings, dense_embeddings
@@ -1517,8 +1500,8 @@ class SamHQModel(SamHQPreTrainedModel):
return SamHQImageSegmentationOutput(
iou_scores=mask_decoder_output[1],
pred_masks=mask_decoder_output[0],
vision_hidden_states=vision_outputs.hidden_states,
vision_attentions=vision_outputs.attentions,
vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None,
vision_attentions=vision_outputs.attentions if pixel_values is not None else None,
)

View File

@@ -327,7 +327,7 @@ class SamHQMaskDecoder(nn.Module):
- (Optional) A tuple containing attention tensors if output_attentions is True.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
@@ -350,7 +350,7 @@ class SamHQMaskDecoder(nn.Module):
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
if torch.any(sparse_prompt_embeddings != 0):
if sparse_prompt_embeddings is not None:
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
else:
tokens = output_tokens
@@ -641,8 +641,8 @@ class SamHQModel(SamModel):
return SamHQImageSegmentationOutput(
iou_scores=mask_decoder_output[1],
pred_masks=mask_decoder_output[0],
vision_hidden_states=vision_outputs.hidden_states,
vision_attentions=vision_outputs.attentions,
vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None,
vision_attentions=vision_outputs.attentions if pixel_values is not None else None,
)

View File

@@ -806,7 +806,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
(None, None): [-13.1695, -14.6201, -14.8989],
("cuda", 8): [-13.1668, -14.6182, -14.8970],
("cuda", 8): [-7.6769, -9.6935, -9.8773],
}
)
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
@@ -831,9 +831,9 @@ class SamHQModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs)
scores = outputs.iou_scores.squeeze()
masks = outputs.pred_masks[0, 0, 0, 0, :3]
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4))
self.assertTrue(
torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2)
torch.testing.assert_close(scores[-1], torch.tensor(0.9700).to(torch_device), atol=2e-4, rtol=2e-4)
torch.testing.assert_close(
masks, torch.tensor([-9.2033, -8.5505, -7.1361]).to(torch_device), atol=3e-2, rtol=3e-2
)
def test_inference_mask_generation_batched_points_batched_images(self):
@@ -895,7 +895,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
expectations = Expectations(
{
(None, None): [-40.2445, -37.4300, -38.1577],
("cuda", 8): [-40.2351, -37.4334, -38.1526],
("cuda", 8): [-14.1195, -17.2663, -13.7805],
}
)
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)