From 433d2a23d734914e0a9903c6b7039d28daa6cbc1 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Fri, 18 Jul 2025 18:46:27 -0400 Subject: [PATCH] Update SAM/SAM HQ attention implementation + fix Cuda sync issues (#39386) * update attention implementation and improve inference speed * modular sam_hq + fix integration tests on A10 * fixup * fix after review * softmax in correct place * return attn_weights in sam/sam_hq --- src/transformers/models/sam/modeling_sam.py | 159 ++++++++---------- .../models/sam_hq/modeling_sam_hq.py | 159 ++++++++---------- .../models/sam_hq/modular_sam_hq.py | 8 +- tests/models/sam_hq/test_modeling_sam_hq.py | 10 +- 4 files changed, 149 insertions(+), 187 deletions(-) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index b4446d68f5..8a607f237c 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -16,7 +16,7 @@ import collections from dataclasses import dataclass -from typing import Optional, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -28,7 +28,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( ModelOutput, @@ -178,6 +178,28 @@ class SamLayerNorm(nn.Module): return x +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class SamAttention(nn.Module): """ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and @@ -186,6 +208,7 @@ class SamAttention(nn.Module): def __init__(self, config, downsample_rate=None): super().__init__() + self.config = config self.hidden_size = config.hidden_size downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate @@ -194,12 +217,15 @@ class SamAttention(nn.Module): self.num_attention_heads = config.num_attention_heads if self.internal_dim % config.num_attention_heads != 0: raise ValueError("num_attention_heads must divide hidden_size.") + self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5 self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + self.is_causal = False + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: batch, point_batch_size, n_tokens, channel = hidden_states.shape c_per_head = channel // num_attention_heads @@ -207,12 +233,16 @@ class SamAttention(nn.Module): return hidden_states.transpose(1, 2) def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: - batch, n_heads, n_tokens, c_per_head = hidden_states.shape - hidden_states = hidden_states.transpose(1, 2) + batch, n_tokens, n_heads, c_per_head = hidden_states.shape return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) def forward( - self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_similarity: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Tensor: # Input projections query = self.q_proj(query) @@ -226,64 +256,26 @@ class SamAttention(nn.Module): value = self._separate_heads(value, self.num_attention_heads) # SamAttention - _, _, _, c_per_head = query.shape - attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens - attn = attn / (c_per_head**0.5) - attn = torch.softmax(attn, dim=-1) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_similarity is not None: - attn = attn + attention_similarity - attn = torch.softmax(attn, dim=-1) + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) - # Get output - out = attn @ value - out = self._recombine_heads(out, point_batch_size) - out = self.out_proj(out) + attn_output = self._recombine_heads(attn_output, point_batch_size) + attn_output = self.out_proj(attn_output) - return out - - -class SamSdpaAttention(SamAttention): - """ - SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and - values. Using SDPA instead of the default attention. - """ - - def __init__(self, config, downsample_rate=None): - super().__init__(config, downsample_rate) - - def forward( - self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None - ) -> Tensor: - # Input projections - query = self.q_proj(query) - key = self.k_proj(key) - value = self.v_proj(value) - - point_batch_size = query.shape[1] - # Separate into heads - query = self._separate_heads(query, self.num_attention_heads) - key = self._separate_heads(key, self.num_attention_heads) - value = self._separate_heads(value, self.num_attention_heads) - - # Scaled dot product attention - attn_mask = None - if attention_similarity is not None: - attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1) - - out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask) - - # Get output - out = self._recombine_heads(out, point_batch_size) - out = self.out_proj(out) - - return out - - -SAM_ATTENTION_CLASSES = { - "eager": SamAttention, - "sdpa": SamSdpaAttention, -} + return attn_output, attn_weights class SamTwoWayAttentionBlock(nn.Module): @@ -306,21 +298,17 @@ class SamTwoWayAttentionBlock(nn.Module): self.hidden_size = config.hidden_size self.layer_norm_eps = config.layer_norm_eps - self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) + self.self_attn = SamAttention(config, downsample_rate=1) self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation]( - config, downsample_rate=attention_downsample_rate - ) + self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate) self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.mlp = SamMLPBlock(config) self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation]( - config, downsample_rate=attention_downsample_rate - ) + self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate) self.skip_first_layer_pe = skip_first_layer_pe def forward( @@ -330,13 +318,14 @@ class SamTwoWayAttentionBlock(nn.Module): query_point_embedding: Tensor, key_point_embedding: Tensor, attention_similarity: Tensor, + **kwargs: Unpack[TransformersKwargs], ): # Self attention block if self.skip_first_layer_pe: - queries = self.self_attn(query=queries, key=queries, value=queries) + queries, _ = self.self_attn(query=queries, key=queries, value=queries) else: query = queries + query_point_embedding - attn_out = self.self_attn(query=query, key=query, value=queries) + attn_out, _ = self.self_attn(query=query, key=query, value=queries) queries = queries + attn_out queries = self.layer_norm1(queries) @@ -344,7 +333,7 @@ class SamTwoWayAttentionBlock(nn.Module): query = queries + query_point_embedding key = keys + key_point_embedding - attn_out = self.cross_attn_token_to_image( + attn_out, _ = self.cross_attn_token_to_image( query=query, key=key, value=keys, attention_similarity=attention_similarity ) queries = queries + attn_out @@ -360,7 +349,7 @@ class SamTwoWayAttentionBlock(nn.Module): query = queries + query_point_embedding key = keys + key_point_embedding - attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) keys = keys + attn_out keys = self.layer_norm4(keys) @@ -378,7 +367,7 @@ class SamTwoWayTransformer(nn.Module): for i in range(self.num_hidden_layers): self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) - self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config) + self.final_attn_token_to_image = SamAttention(config) self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) def forward( @@ -388,6 +377,7 @@ class SamTwoWayTransformer(nn.Module): image_positional_embeddings: Tensor, attention_similarity: Tensor, target_embedding=None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, BaseModelOutput]: if image_embeddings is None: raise ValueError("You have to specify an image_embedding") @@ -410,12 +400,13 @@ class SamTwoWayTransformer(nn.Module): query_point_embedding=point_embeddings, key_point_embedding=image_positional_embeddings, attention_similarity=attention_similarity, + **kwargs, ) # Apply the final attenion layer from the points to the image query = queries + point_embeddings key = keys + image_positional_embeddings - attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) queries = queries + attn_out queries = self.layer_norm_final_attn(queries) @@ -501,12 +492,12 @@ class SamMaskDecoder(nn.Module): Whether to return multiple masks or a single mask. """ batch_size, num_channels, height, width = image_embeddings.shape - point_batch_size = sparse_prompt_embeddings.shape[1] + point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1 # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) - if sparse_prompt_embeddings.sum().item() != 0: + if sparse_prompt_embeddings is not None: tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2) else: tokens = output_tokens @@ -611,7 +602,7 @@ class SamMaskEmbedding(nn.Module): class SamPromptEncoder(nn.Module): - def __init__(self, config: SamPromptEncoderConfig): + def __init__(self, config: SamConfig): super().__init__() self.shared_embedding = SamPositionalEmbedding(config.vision_config) config = config.prompt_encoder_config @@ -645,11 +636,7 @@ class SamPromptEncoder(nn.Module): # This is required for the ONNX export. The dtype, device need to be explicitly # specified as otherwise torch.onnx.export interprets as double - point_embedding = torch.where( - labels[..., None] != -10, - point_embedding, - torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), - ) + point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding)) point_embedding = torch.where( (labels == 0)[:, :, :, None], @@ -696,9 +683,8 @@ class SamPromptEncoder(nn.Module): """ sparse_embeddings = None batch_size = 1 - target_device = self.shared_embedding.positional_embedding.device if input_points is not None: - batch_size, point_batch_size = input_points.shape[:2] + batch_size = input_points.shape[0] if input_labels is None: raise ValueError("If points are provided, labels must also be provided.") point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) @@ -717,9 +703,6 @@ class SamPromptEncoder(nn.Module): batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] ) - if sparse_embeddings is None: - sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device) - return sparse_embeddings, dense_embeddings @@ -1184,10 +1167,6 @@ class SamModel(SamPreTrainedModel): Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Input pixel values - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. """ vision_output = self.vision_encoder( pixel_values, diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 6418e71720..288d4134d2 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -21,7 +21,7 @@ # limitations under the License. import collections from dataclasses import dataclass -from typing import Optional, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -34,7 +34,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig @@ -601,6 +601,28 @@ class SamHQLayerNorm(nn.Module): return x +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class SamHQAttention(nn.Module): """ SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and @@ -609,6 +631,7 @@ class SamHQAttention(nn.Module): def __init__(self, config, downsample_rate=None): super().__init__() + self.config = config self.hidden_size = config.hidden_size downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate @@ -617,12 +640,15 @@ class SamHQAttention(nn.Module): self.num_attention_heads = config.num_attention_heads if self.internal_dim % config.num_attention_heads != 0: raise ValueError("num_attention_heads must divide hidden_size.") + self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5 self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + self.is_causal = False + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: batch, point_batch_size, n_tokens, channel = hidden_states.shape c_per_head = channel // num_attention_heads @@ -630,12 +656,16 @@ class SamHQAttention(nn.Module): return hidden_states.transpose(1, 2) def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: - batch, n_heads, n_tokens, c_per_head = hidden_states.shape - hidden_states = hidden_states.transpose(1, 2) + batch, n_tokens, n_heads, c_per_head = hidden_states.shape return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) def forward( - self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_similarity: Optional[Tensor] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Tensor: # Input projections query = self.q_proj(query) @@ -649,64 +679,26 @@ class SamHQAttention(nn.Module): value = self._separate_heads(value, self.num_attention_heads) # SamHQAttention - _, _, _, c_per_head = query.shape - attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens - attn = attn / (c_per_head**0.5) - attn = torch.softmax(attn, dim=-1) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_similarity is not None: - attn = attn + attention_similarity - attn = torch.softmax(attn, dim=-1) + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) - # Get output - out = attn @ value - out = self._recombine_heads(out, point_batch_size) - out = self.out_proj(out) + attn_output = self._recombine_heads(attn_output, point_batch_size) + attn_output = self.out_proj(attn_output) - return out - - -class SamHQSdpaAttention(SamHQAttention): - """ - SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and - values. Using SDPA instead of the default attention. - """ - - def __init__(self, config, downsample_rate=None): - super().__init__(config, downsample_rate) - - def forward( - self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None - ) -> Tensor: - # Input projections - query = self.q_proj(query) - key = self.k_proj(key) - value = self.v_proj(value) - - point_batch_size = query.shape[1] - # Separate into heads - query = self._separate_heads(query, self.num_attention_heads) - key = self._separate_heads(key, self.num_attention_heads) - value = self._separate_heads(value, self.num_attention_heads) - - # Scaled dot product attention - attn_mask = None - if attention_similarity is not None: - attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1) - - out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask) - - # Get output - out = self._recombine_heads(out, point_batch_size) - out = self.out_proj(out) - - return out - - -SAM_HQ_ATTENTION_CLASSES = { - "eager": SamHQAttention, - "sdpa": SamHQSdpaAttention, -} + return attn_output, attn_weights class SamHQTwoWayAttentionBlock(nn.Module): @@ -729,21 +721,17 @@ class SamHQTwoWayAttentionBlock(nn.Module): self.hidden_size = config.hidden_size self.layer_norm_eps = config.layer_norm_eps - self.self_attn = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) + self.self_attn = SamHQAttention(config, downsample_rate=1) self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - self.cross_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation]( - config, downsample_rate=attention_downsample_rate - ) + self.cross_attn_token_to_image = SamHQAttention(config, downsample_rate=attention_downsample_rate) self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.mlp = SamHQMLPBlock(config) self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) - self.cross_attn_image_to_token = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation]( - config, downsample_rate=attention_downsample_rate - ) + self.cross_attn_image_to_token = SamHQAttention(config, downsample_rate=attention_downsample_rate) self.skip_first_layer_pe = skip_first_layer_pe def forward( @@ -753,13 +741,14 @@ class SamHQTwoWayAttentionBlock(nn.Module): query_point_embedding: Tensor, key_point_embedding: Tensor, attention_similarity: Tensor, + **kwargs: Unpack[TransformersKwargs], ): # Self attention block if self.skip_first_layer_pe: - queries = self.self_attn(query=queries, key=queries, value=queries) + queries, _ = self.self_attn(query=queries, key=queries, value=queries) else: query = queries + query_point_embedding - attn_out = self.self_attn(query=query, key=query, value=queries) + attn_out, _ = self.self_attn(query=query, key=query, value=queries) queries = queries + attn_out queries = self.layer_norm1(queries) @@ -767,7 +756,7 @@ class SamHQTwoWayAttentionBlock(nn.Module): query = queries + query_point_embedding key = keys + key_point_embedding - attn_out = self.cross_attn_token_to_image( + attn_out, _ = self.cross_attn_token_to_image( query=query, key=key, value=keys, attention_similarity=attention_similarity ) queries = queries + attn_out @@ -783,7 +772,7 @@ class SamHQTwoWayAttentionBlock(nn.Module): query = queries + query_point_embedding key = keys + key_point_embedding - attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) keys = keys + attn_out keys = self.layer_norm4(keys) @@ -801,7 +790,7 @@ class SamHQTwoWayTransformer(nn.Module): for i in range(self.num_hidden_layers): self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) - self.final_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config) + self.final_attn_token_to_image = SamHQAttention(config) self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) def forward( @@ -811,6 +800,7 @@ class SamHQTwoWayTransformer(nn.Module): image_positional_embeddings: Tensor, attention_similarity: Tensor, target_embedding=None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, BaseModelOutput]: if image_embeddings is None: raise ValueError("You have to specify an image_embedding") @@ -833,12 +823,13 @@ class SamHQTwoWayTransformer(nn.Module): query_point_embedding=point_embeddings, key_point_embedding=image_positional_embeddings, attention_similarity=attention_similarity, + **kwargs, ) # Apply the final attenion layer from the points to the image query = queries + point_embeddings key = keys + image_positional_embeddings - attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) queries = queries + attn_out queries = self.layer_norm_final_attn(queries) @@ -957,7 +948,7 @@ class SamHQMaskDecoder(nn.Module): - (Optional) A tuple containing attention tensors if output_attentions is True. """ batch_size, num_channels, height, width = image_embeddings.shape - point_batch_size = sparse_prompt_embeddings.shape[1] + point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1 has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0 @@ -980,7 +971,7 @@ class SamHQMaskDecoder(nn.Module): output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) - if torch.any(sparse_prompt_embeddings != 0): + if sparse_prompt_embeddings is not None: tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) else: tokens = output_tokens @@ -1147,7 +1138,7 @@ class SamHQMaskEmbedding(nn.Module): class SamHQPromptEncoder(nn.Module): - def __init__(self, config: SamHQPromptEncoderConfig): + def __init__(self, config: SamHQConfig): super().__init__() self.shared_embedding = SamHQPositionalEmbedding(config.vision_config) config = config.prompt_encoder_config @@ -1181,11 +1172,7 @@ class SamHQPromptEncoder(nn.Module): # This is required for the ONNX export. The dtype, device need to be explicitly # specified as otherwise torch.onnx.export interprets as double - point_embedding = torch.where( - labels[..., None] != -10, - point_embedding, - torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), - ) + point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding)) point_embedding = torch.where( (labels == 0)[:, :, :, None], @@ -1232,9 +1219,8 @@ class SamHQPromptEncoder(nn.Module): """ sparse_embeddings = None batch_size = 1 - target_device = self.shared_embedding.positional_embedding.device if input_points is not None: - batch_size, point_batch_size = input_points.shape[:2] + batch_size = input_points.shape[0] if input_labels is None: raise ValueError("If points are provided, labels must also be provided.") point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) @@ -1253,9 +1239,6 @@ class SamHQPromptEncoder(nn.Module): batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] ) - if sparse_embeddings is None: - sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device) - return sparse_embeddings, dense_embeddings @@ -1517,8 +1500,8 @@ class SamHQModel(SamHQPreTrainedModel): return SamHQImageSegmentationOutput( iou_scores=mask_decoder_output[1], pred_masks=mask_decoder_output[0], - vision_hidden_states=vision_outputs.hidden_states, - vision_attentions=vision_outputs.attentions, + vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None, + vision_attentions=vision_outputs.attentions if pixel_values is not None else None, ) diff --git a/src/transformers/models/sam_hq/modular_sam_hq.py b/src/transformers/models/sam_hq/modular_sam_hq.py index 67772cb6c4..67399295c6 100644 --- a/src/transformers/models/sam_hq/modular_sam_hq.py +++ b/src/transformers/models/sam_hq/modular_sam_hq.py @@ -327,7 +327,7 @@ class SamHQMaskDecoder(nn.Module): - (Optional) A tuple containing attention tensors if output_attentions is True. """ batch_size, num_channels, height, width = image_embeddings.shape - point_batch_size = sparse_prompt_embeddings.shape[1] + point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1 has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0 @@ -350,7 +350,7 @@ class SamHQMaskDecoder(nn.Module): output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) - if torch.any(sparse_prompt_embeddings != 0): + if sparse_prompt_embeddings is not None: tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) else: tokens = output_tokens @@ -641,8 +641,8 @@ class SamHQModel(SamModel): return SamHQImageSegmentationOutput( iou_scores=mask_decoder_output[1], pred_masks=mask_decoder_output[0], - vision_hidden_states=vision_outputs.hidden_states, - vision_attentions=vision_outputs.attentions, + vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None, + vision_attentions=vision_outputs.attentions if pixel_values is not None else None, ) diff --git a/tests/models/sam_hq/test_modeling_sam_hq.py b/tests/models/sam_hq/test_modeling_sam_hq.py index d62ef664b9..192f7c8b02 100644 --- a/tests/models/sam_hq/test_modeling_sam_hq.py +++ b/tests/models/sam_hq/test_modeling_sam_hq.py @@ -806,7 +806,7 @@ class SamHQModelIntegrationTest(unittest.TestCase): expectations = Expectations( { (None, None): [-13.1695, -14.6201, -14.8989], - ("cuda", 8): [-13.1668, -14.6182, -14.8970], + ("cuda", 8): [-7.6769, -9.6935, -9.8773], } ) EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device) @@ -831,9 +831,9 @@ class SamHQModelIntegrationTest(unittest.TestCase): outputs = model(**inputs) scores = outputs.iou_scores.squeeze() masks = outputs.pred_masks[0, 0, 0, 0, :3] - self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4)) - self.assertTrue( - torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2) + torch.testing.assert_close(scores[-1], torch.tensor(0.9700).to(torch_device), atol=2e-4, rtol=2e-4) + torch.testing.assert_close( + masks, torch.tensor([-9.2033, -8.5505, -7.1361]).to(torch_device), atol=3e-2, rtol=3e-2 ) def test_inference_mask_generation_batched_points_batched_images(self): @@ -895,7 +895,7 @@ class SamHQModelIntegrationTest(unittest.TestCase): expectations = Expectations( { (None, None): [-40.2445, -37.4300, -38.1577], - ("cuda", 8): [-40.2351, -37.4334, -38.1526], + ("cuda", 8): [-14.1195, -17.2663, -13.7805], } ) EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)