Update SAM/SAM HQ attention implementation + fix Cuda sync issues (#39386)
* update attention implementation and improve inference speed * modular sam_hq + fix integration tests on A10 * fixup * fix after review * softmax in correct place * return attn_weights in sam/sam_hq
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -28,7 +28,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_layers import GradientCheckpointingLayer
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutput
|
from ...modeling_outputs import BaseModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -178,6 +178,28 @@ class SamLayerNorm(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def eager_attention_forward(
|
||||||
|
module: nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
scaling: float,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class SamAttention(nn.Module):
|
class SamAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||||
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, config, downsample_rate=None):
|
def __init__(self, config, downsample_rate=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
|
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
|
||||||
@@ -194,12 +217,15 @@ class SamAttention(nn.Module):
|
|||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
if self.internal_dim % config.num_attention_heads != 0:
|
if self.internal_dim % config.num_attention_heads != 0:
|
||||||
raise ValueError("num_attention_heads must divide hidden_size.")
|
raise ValueError("num_attention_heads must divide hidden_size.")
|
||||||
|
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
|
||||||
|
|
||||||
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
||||||
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
||||||
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
||||||
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
|
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
|
||||||
|
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
||||||
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
||||||
c_per_head = channel // num_attention_heads
|
c_per_head = channel // num_attention_heads
|
||||||
@@ -207,12 +233,16 @@ class SamAttention(nn.Module):
|
|||||||
return hidden_states.transpose(1, 2)
|
return hidden_states.transpose(1, 2)
|
||||||
|
|
||||||
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
||||||
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
|
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
|
||||||
hidden_states = hidden_states.transpose(1, 2)
|
|
||||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
attention_similarity: Optional[Tensor] = None,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# Input projections
|
# Input projections
|
||||||
query = self.q_proj(query)
|
query = self.q_proj(query)
|
||||||
@@ -226,64 +256,26 @@ class SamAttention(nn.Module):
|
|||||||
value = self._separate_heads(value, self.num_attention_heads)
|
value = self._separate_heads(value, self.num_attention_heads)
|
||||||
|
|
||||||
# SamAttention
|
# SamAttention
|
||||||
_, _, _, c_per_head = query.shape
|
attention_interface: Callable = eager_attention_forward
|
||||||
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
if self.config._attn_implementation != "eager":
|
||||||
attn = attn / (c_per_head**0.5)
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
attn = torch.softmax(attn, dim=-1)
|
|
||||||
|
|
||||||
if attention_similarity is not None:
|
attn_output, attn_weights = attention_interface(
|
||||||
attn = attn + attention_similarity
|
self,
|
||||||
attn = torch.softmax(attn, dim=-1)
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attention_mask=attention_similarity,
|
||||||
|
dropout=0.0 if not self.training else self.dropout_p,
|
||||||
|
scaling=self.scaling,
|
||||||
|
is_causal=self.is_causal,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Get output
|
attn_output = self._recombine_heads(attn_output, point_batch_size)
|
||||||
out = attn @ value
|
attn_output = self.out_proj(attn_output)
|
||||||
out = self._recombine_heads(out, point_batch_size)
|
|
||||||
out = self.out_proj(out)
|
|
||||||
|
|
||||||
return out
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class SamSdpaAttention(SamAttention):
|
|
||||||
"""
|
|
||||||
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
|
||||||
values. Using SDPA instead of the default attention.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, downsample_rate=None):
|
|
||||||
super().__init__(config, downsample_rate)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
|
|
||||||
) -> Tensor:
|
|
||||||
# Input projections
|
|
||||||
query = self.q_proj(query)
|
|
||||||
key = self.k_proj(key)
|
|
||||||
value = self.v_proj(value)
|
|
||||||
|
|
||||||
point_batch_size = query.shape[1]
|
|
||||||
# Separate into heads
|
|
||||||
query = self._separate_heads(query, self.num_attention_heads)
|
|
||||||
key = self._separate_heads(key, self.num_attention_heads)
|
|
||||||
value = self._separate_heads(value, self.num_attention_heads)
|
|
||||||
|
|
||||||
# Scaled dot product attention
|
|
||||||
attn_mask = None
|
|
||||||
if attention_similarity is not None:
|
|
||||||
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
|
|
||||||
|
|
||||||
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
|
|
||||||
|
|
||||||
# Get output
|
|
||||||
out = self._recombine_heads(out, point_batch_size)
|
|
||||||
out = self.out_proj(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
SAM_ATTENTION_CLASSES = {
|
|
||||||
"eager": SamAttention,
|
|
||||||
"sdpa": SamSdpaAttention,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SamTwoWayAttentionBlock(nn.Module):
|
class SamTwoWayAttentionBlock(nn.Module):
|
||||||
@@ -306,21 +298,17 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.layer_norm_eps = config.layer_norm_eps
|
self.layer_norm_eps = config.layer_norm_eps
|
||||||
|
|
||||||
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
|
self.self_attn = SamAttention(config, downsample_rate=1)
|
||||||
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
|
|
||||||
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
|
self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
|
||||||
config, downsample_rate=attention_downsample_rate
|
|
||||||
)
|
|
||||||
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
|
|
||||||
self.mlp = SamMLPBlock(config)
|
self.mlp = SamMLPBlock(config)
|
||||||
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
|
|
||||||
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
|
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
|
||||||
config, downsample_rate=attention_downsample_rate
|
|
||||||
)
|
|
||||||
self.skip_first_layer_pe = skip_first_layer_pe
|
self.skip_first_layer_pe = skip_first_layer_pe
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -330,13 +318,14 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||||||
query_point_embedding: Tensor,
|
query_point_embedding: Tensor,
|
||||||
key_point_embedding: Tensor,
|
key_point_embedding: Tensor,
|
||||||
attention_similarity: Tensor,
|
attention_similarity: Tensor,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
):
|
):
|
||||||
# Self attention block
|
# Self attention block
|
||||||
if self.skip_first_layer_pe:
|
if self.skip_first_layer_pe:
|
||||||
queries = self.self_attn(query=queries, key=queries, value=queries)
|
queries, _ = self.self_attn(query=queries, key=queries, value=queries)
|
||||||
else:
|
else:
|
||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
attn_out = self.self_attn(query=query, key=query, value=queries)
|
attn_out, _ = self.self_attn(query=query, key=query, value=queries)
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
queries = self.layer_norm1(queries)
|
queries = self.layer_norm1(queries)
|
||||||
|
|
||||||
@@ -344,7 +333,7 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
key = keys + key_point_embedding
|
key = keys + key_point_embedding
|
||||||
|
|
||||||
attn_out = self.cross_attn_token_to_image(
|
attn_out, _ = self.cross_attn_token_to_image(
|
||||||
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
||||||
)
|
)
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
@@ -360,7 +349,7 @@ class SamTwoWayAttentionBlock(nn.Module):
|
|||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
key = keys + key_point_embedding
|
key = keys + key_point_embedding
|
||||||
|
|
||||||
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
|
attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
|
||||||
keys = keys + attn_out
|
keys = keys + attn_out
|
||||||
|
|
||||||
keys = self.layer_norm4(keys)
|
keys = self.layer_norm4(keys)
|
||||||
@@ -378,7 +367,7 @@ class SamTwoWayTransformer(nn.Module):
|
|||||||
for i in range(self.num_hidden_layers):
|
for i in range(self.num_hidden_layers):
|
||||||
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
|
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
|
||||||
|
|
||||||
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
|
self.final_attn_token_to_image = SamAttention(config)
|
||||||
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
|
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -388,6 +377,7 @@ class SamTwoWayTransformer(nn.Module):
|
|||||||
image_positional_embeddings: Tensor,
|
image_positional_embeddings: Tensor,
|
||||||
attention_similarity: Tensor,
|
attention_similarity: Tensor,
|
||||||
target_embedding=None,
|
target_embedding=None,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Union[tuple, BaseModelOutput]:
|
) -> Union[tuple, BaseModelOutput]:
|
||||||
if image_embeddings is None:
|
if image_embeddings is None:
|
||||||
raise ValueError("You have to specify an image_embedding")
|
raise ValueError("You have to specify an image_embedding")
|
||||||
@@ -410,12 +400,13 @@ class SamTwoWayTransformer(nn.Module):
|
|||||||
query_point_embedding=point_embeddings,
|
query_point_embedding=point_embeddings,
|
||||||
key_point_embedding=image_positional_embeddings,
|
key_point_embedding=image_positional_embeddings,
|
||||||
attention_similarity=attention_similarity,
|
attention_similarity=attention_similarity,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# Apply the final attenion layer from the points to the image
|
# Apply the final attenion layer from the points to the image
|
||||||
query = queries + point_embeddings
|
query = queries + point_embeddings
|
||||||
key = keys + image_positional_embeddings
|
key = keys + image_positional_embeddings
|
||||||
|
|
||||||
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
|
attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
|
||||||
|
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
queries = self.layer_norm_final_attn(queries)
|
queries = self.layer_norm_final_attn(queries)
|
||||||
@@ -501,12 +492,12 @@ class SamMaskDecoder(nn.Module):
|
|||||||
Whether to return multiple masks or a single mask.
|
Whether to return multiple masks or a single mask.
|
||||||
"""
|
"""
|
||||||
batch_size, num_channels, height, width = image_embeddings.shape
|
batch_size, num_channels, height, width = image_embeddings.shape
|
||||||
point_batch_size = sparse_prompt_embeddings.shape[1]
|
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
|
||||||
# Concatenate output tokens
|
# Concatenate output tokens
|
||||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||||
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
||||||
|
|
||||||
if sparse_prompt_embeddings.sum().item() != 0:
|
if sparse_prompt_embeddings is not None:
|
||||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
|
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
|
||||||
else:
|
else:
|
||||||
tokens = output_tokens
|
tokens = output_tokens
|
||||||
@@ -611,7 +602,7 @@ class SamMaskEmbedding(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SamPromptEncoder(nn.Module):
|
class SamPromptEncoder(nn.Module):
|
||||||
def __init__(self, config: SamPromptEncoderConfig):
|
def __init__(self, config: SamConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
|
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
|
||||||
config = config.prompt_encoder_config
|
config = config.prompt_encoder_config
|
||||||
@@ -645,11 +636,7 @@ class SamPromptEncoder(nn.Module):
|
|||||||
|
|
||||||
# This is required for the ONNX export. The dtype, device need to be explicitly
|
# This is required for the ONNX export. The dtype, device need to be explicitly
|
||||||
# specified as otherwise torch.onnx.export interprets as double
|
# specified as otherwise torch.onnx.export interprets as double
|
||||||
point_embedding = torch.where(
|
point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
|
||||||
labels[..., None] != -10,
|
|
||||||
point_embedding,
|
|
||||||
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
point_embedding = torch.where(
|
point_embedding = torch.where(
|
||||||
(labels == 0)[:, :, :, None],
|
(labels == 0)[:, :, :, None],
|
||||||
@@ -696,9 +683,8 @@ class SamPromptEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
sparse_embeddings = None
|
sparse_embeddings = None
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
target_device = self.shared_embedding.positional_embedding.device
|
|
||||||
if input_points is not None:
|
if input_points is not None:
|
||||||
batch_size, point_batch_size = input_points.shape[:2]
|
batch_size = input_points.shape[0]
|
||||||
if input_labels is None:
|
if input_labels is None:
|
||||||
raise ValueError("If points are provided, labels must also be provided.")
|
raise ValueError("If points are provided, labels must also be provided.")
|
||||||
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
||||||
@@ -717,9 +703,6 @@ class SamPromptEncoder(nn.Module):
|
|||||||
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if sparse_embeddings is None:
|
|
||||||
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
|
|
||||||
|
|
||||||
return sparse_embeddings, dense_embeddings
|
return sparse_embeddings, dense_embeddings
|
||||||
|
|
||||||
|
|
||||||
@@ -1184,10 +1167,6 @@ class SamModel(SamPreTrainedModel):
|
|||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
Input pixel values
|
Input pixel values
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers.
|
|
||||||
"""
|
"""
|
||||||
vision_output = self.vision_encoder(
|
vision_output = self.vision_encoder(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections
|
import collections
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -34,7 +34,7 @@ from transformers.utils.generic import OutputRecorder, TransformersKwargs, check
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_layers import GradientCheckpointingLayer
|
from ...modeling_layers import GradientCheckpointingLayer
|
||||||
from ...modeling_outputs import BaseModelOutput
|
from ...modeling_outputs import BaseModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import auto_docstring, logging
|
from ...utils import auto_docstring, logging
|
||||||
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig
|
from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig
|
||||||
@@ -601,6 +601,28 @@ class SamHQLayerNorm(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def eager_attention_forward(
|
||||||
|
module: nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
scaling: float,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class SamHQAttention(nn.Module):
|
class SamHQAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||||
@@ -609,6 +631,7 @@ class SamHQAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, config, downsample_rate=None):
|
def __init__(self, config, downsample_rate=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
|
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
|
||||||
@@ -617,12 +640,15 @@ class SamHQAttention(nn.Module):
|
|||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
if self.internal_dim % config.num_attention_heads != 0:
|
if self.internal_dim % config.num_attention_heads != 0:
|
||||||
raise ValueError("num_attention_heads must divide hidden_size.")
|
raise ValueError("num_attention_heads must divide hidden_size.")
|
||||||
|
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
|
||||||
|
|
||||||
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
||||||
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
||||||
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
|
||||||
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
|
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
|
||||||
|
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
||||||
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
||||||
c_per_head = channel // num_attention_heads
|
c_per_head = channel // num_attention_heads
|
||||||
@@ -630,12 +656,16 @@ class SamHQAttention(nn.Module):
|
|||||||
return hidden_states.transpose(1, 2)
|
return hidden_states.transpose(1, 2)
|
||||||
|
|
||||||
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
||||||
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
|
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
|
||||||
hidden_states = hidden_states.transpose(1, 2)
|
|
||||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
attention_similarity: Optional[Tensor] = None,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# Input projections
|
# Input projections
|
||||||
query = self.q_proj(query)
|
query = self.q_proj(query)
|
||||||
@@ -649,64 +679,26 @@ class SamHQAttention(nn.Module):
|
|||||||
value = self._separate_heads(value, self.num_attention_heads)
|
value = self._separate_heads(value, self.num_attention_heads)
|
||||||
|
|
||||||
# SamHQAttention
|
# SamHQAttention
|
||||||
_, _, _, c_per_head = query.shape
|
attention_interface: Callable = eager_attention_forward
|
||||||
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
if self.config._attn_implementation != "eager":
|
||||||
attn = attn / (c_per_head**0.5)
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
attn = torch.softmax(attn, dim=-1)
|
|
||||||
|
|
||||||
if attention_similarity is not None:
|
attn_output, attn_weights = attention_interface(
|
||||||
attn = attn + attention_similarity
|
self,
|
||||||
attn = torch.softmax(attn, dim=-1)
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attention_mask=attention_similarity,
|
||||||
|
dropout=0.0 if not self.training else self.dropout_p,
|
||||||
|
scaling=self.scaling,
|
||||||
|
is_causal=self.is_causal,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Get output
|
attn_output = self._recombine_heads(attn_output, point_batch_size)
|
||||||
out = attn @ value
|
attn_output = self.out_proj(attn_output)
|
||||||
out = self._recombine_heads(out, point_batch_size)
|
|
||||||
out = self.out_proj(out)
|
|
||||||
|
|
||||||
return out
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
class SamHQSdpaAttention(SamHQAttention):
|
|
||||||
"""
|
|
||||||
SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
|
||||||
values. Using SDPA instead of the default attention.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, downsample_rate=None):
|
|
||||||
super().__init__(config, downsample_rate)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
|
|
||||||
) -> Tensor:
|
|
||||||
# Input projections
|
|
||||||
query = self.q_proj(query)
|
|
||||||
key = self.k_proj(key)
|
|
||||||
value = self.v_proj(value)
|
|
||||||
|
|
||||||
point_batch_size = query.shape[1]
|
|
||||||
# Separate into heads
|
|
||||||
query = self._separate_heads(query, self.num_attention_heads)
|
|
||||||
key = self._separate_heads(key, self.num_attention_heads)
|
|
||||||
value = self._separate_heads(value, self.num_attention_heads)
|
|
||||||
|
|
||||||
# Scaled dot product attention
|
|
||||||
attn_mask = None
|
|
||||||
if attention_similarity is not None:
|
|
||||||
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
|
|
||||||
|
|
||||||
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
|
|
||||||
|
|
||||||
# Get output
|
|
||||||
out = self._recombine_heads(out, point_batch_size)
|
|
||||||
out = self.out_proj(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
SAM_HQ_ATTENTION_CLASSES = {
|
|
||||||
"eager": SamHQAttention,
|
|
||||||
"sdpa": SamHQSdpaAttention,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SamHQTwoWayAttentionBlock(nn.Module):
|
class SamHQTwoWayAttentionBlock(nn.Module):
|
||||||
@@ -729,21 +721,17 @@ class SamHQTwoWayAttentionBlock(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.layer_norm_eps = config.layer_norm_eps
|
self.layer_norm_eps = config.layer_norm_eps
|
||||||
|
|
||||||
self.self_attn = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
|
self.self_attn = SamHQAttention(config, downsample_rate=1)
|
||||||
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
|
|
||||||
self.cross_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](
|
self.cross_attn_token_to_image = SamHQAttention(config, downsample_rate=attention_downsample_rate)
|
||||||
config, downsample_rate=attention_downsample_rate
|
|
||||||
)
|
|
||||||
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
|
|
||||||
self.mlp = SamHQMLPBlock(config)
|
self.mlp = SamHQMLPBlock(config)
|
||||||
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
|
|
||||||
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||||
self.cross_attn_image_to_token = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](
|
self.cross_attn_image_to_token = SamHQAttention(config, downsample_rate=attention_downsample_rate)
|
||||||
config, downsample_rate=attention_downsample_rate
|
|
||||||
)
|
|
||||||
self.skip_first_layer_pe = skip_first_layer_pe
|
self.skip_first_layer_pe = skip_first_layer_pe
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -753,13 +741,14 @@ class SamHQTwoWayAttentionBlock(nn.Module):
|
|||||||
query_point_embedding: Tensor,
|
query_point_embedding: Tensor,
|
||||||
key_point_embedding: Tensor,
|
key_point_embedding: Tensor,
|
||||||
attention_similarity: Tensor,
|
attention_similarity: Tensor,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
):
|
):
|
||||||
# Self attention block
|
# Self attention block
|
||||||
if self.skip_first_layer_pe:
|
if self.skip_first_layer_pe:
|
||||||
queries = self.self_attn(query=queries, key=queries, value=queries)
|
queries, _ = self.self_attn(query=queries, key=queries, value=queries)
|
||||||
else:
|
else:
|
||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
attn_out = self.self_attn(query=query, key=query, value=queries)
|
attn_out, _ = self.self_attn(query=query, key=query, value=queries)
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
queries = self.layer_norm1(queries)
|
queries = self.layer_norm1(queries)
|
||||||
|
|
||||||
@@ -767,7 +756,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
|
|||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
key = keys + key_point_embedding
|
key = keys + key_point_embedding
|
||||||
|
|
||||||
attn_out = self.cross_attn_token_to_image(
|
attn_out, _ = self.cross_attn_token_to_image(
|
||||||
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
query=query, key=key, value=keys, attention_similarity=attention_similarity
|
||||||
)
|
)
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
@@ -783,7 +772,7 @@ class SamHQTwoWayAttentionBlock(nn.Module):
|
|||||||
query = queries + query_point_embedding
|
query = queries + query_point_embedding
|
||||||
key = keys + key_point_embedding
|
key = keys + key_point_embedding
|
||||||
|
|
||||||
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
|
attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
|
||||||
keys = keys + attn_out
|
keys = keys + attn_out
|
||||||
|
|
||||||
keys = self.layer_norm4(keys)
|
keys = self.layer_norm4(keys)
|
||||||
@@ -801,7 +790,7 @@ class SamHQTwoWayTransformer(nn.Module):
|
|||||||
for i in range(self.num_hidden_layers):
|
for i in range(self.num_hidden_layers):
|
||||||
self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
|
self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
|
||||||
|
|
||||||
self.final_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config)
|
self.final_attn_token_to_image = SamHQAttention(config)
|
||||||
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
|
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -811,6 +800,7 @@ class SamHQTwoWayTransformer(nn.Module):
|
|||||||
image_positional_embeddings: Tensor,
|
image_positional_embeddings: Tensor,
|
||||||
attention_similarity: Tensor,
|
attention_similarity: Tensor,
|
||||||
target_embedding=None,
|
target_embedding=None,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Union[tuple, BaseModelOutput]:
|
) -> Union[tuple, BaseModelOutput]:
|
||||||
if image_embeddings is None:
|
if image_embeddings is None:
|
||||||
raise ValueError("You have to specify an image_embedding")
|
raise ValueError("You have to specify an image_embedding")
|
||||||
@@ -833,12 +823,13 @@ class SamHQTwoWayTransformer(nn.Module):
|
|||||||
query_point_embedding=point_embeddings,
|
query_point_embedding=point_embeddings,
|
||||||
key_point_embedding=image_positional_embeddings,
|
key_point_embedding=image_positional_embeddings,
|
||||||
attention_similarity=attention_similarity,
|
attention_similarity=attention_similarity,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# Apply the final attenion layer from the points to the image
|
# Apply the final attenion layer from the points to the image
|
||||||
query = queries + point_embeddings
|
query = queries + point_embeddings
|
||||||
key = keys + image_positional_embeddings
|
key = keys + image_positional_embeddings
|
||||||
|
|
||||||
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
|
attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
|
||||||
|
|
||||||
queries = queries + attn_out
|
queries = queries + attn_out
|
||||||
queries = self.layer_norm_final_attn(queries)
|
queries = self.layer_norm_final_attn(queries)
|
||||||
@@ -957,7 +948,7 @@ class SamHQMaskDecoder(nn.Module):
|
|||||||
- (Optional) A tuple containing attention tensors if output_attentions is True.
|
- (Optional) A tuple containing attention tensors if output_attentions is True.
|
||||||
"""
|
"""
|
||||||
batch_size, num_channels, height, width = image_embeddings.shape
|
batch_size, num_channels, height, width = image_embeddings.shape
|
||||||
point_batch_size = sparse_prompt_embeddings.shape[1]
|
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
|
||||||
|
|
||||||
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
|
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
|
||||||
|
|
||||||
@@ -980,7 +971,7 @@ class SamHQMaskDecoder(nn.Module):
|
|||||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
|
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
|
||||||
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
||||||
|
|
||||||
if torch.any(sparse_prompt_embeddings != 0):
|
if sparse_prompt_embeddings is not None:
|
||||||
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
|
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
|
||||||
else:
|
else:
|
||||||
tokens = output_tokens
|
tokens = output_tokens
|
||||||
@@ -1147,7 +1138,7 @@ class SamHQMaskEmbedding(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SamHQPromptEncoder(nn.Module):
|
class SamHQPromptEncoder(nn.Module):
|
||||||
def __init__(self, config: SamHQPromptEncoderConfig):
|
def __init__(self, config: SamHQConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shared_embedding = SamHQPositionalEmbedding(config.vision_config)
|
self.shared_embedding = SamHQPositionalEmbedding(config.vision_config)
|
||||||
config = config.prompt_encoder_config
|
config = config.prompt_encoder_config
|
||||||
@@ -1181,11 +1172,7 @@ class SamHQPromptEncoder(nn.Module):
|
|||||||
|
|
||||||
# This is required for the ONNX export. The dtype, device need to be explicitly
|
# This is required for the ONNX export. The dtype, device need to be explicitly
|
||||||
# specified as otherwise torch.onnx.export interprets as double
|
# specified as otherwise torch.onnx.export interprets as double
|
||||||
point_embedding = torch.where(
|
point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
|
||||||
labels[..., None] != -10,
|
|
||||||
point_embedding,
|
|
||||||
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
point_embedding = torch.where(
|
point_embedding = torch.where(
|
||||||
(labels == 0)[:, :, :, None],
|
(labels == 0)[:, :, :, None],
|
||||||
@@ -1232,9 +1219,8 @@ class SamHQPromptEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
sparse_embeddings = None
|
sparse_embeddings = None
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
target_device = self.shared_embedding.positional_embedding.device
|
|
||||||
if input_points is not None:
|
if input_points is not None:
|
||||||
batch_size, point_batch_size = input_points.shape[:2]
|
batch_size = input_points.shape[0]
|
||||||
if input_labels is None:
|
if input_labels is None:
|
||||||
raise ValueError("If points are provided, labels must also be provided.")
|
raise ValueError("If points are provided, labels must also be provided.")
|
||||||
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
||||||
@@ -1253,9 +1239,6 @@ class SamHQPromptEncoder(nn.Module):
|
|||||||
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if sparse_embeddings is None:
|
|
||||||
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
|
|
||||||
|
|
||||||
return sparse_embeddings, dense_embeddings
|
return sparse_embeddings, dense_embeddings
|
||||||
|
|
||||||
|
|
||||||
@@ -1517,8 +1500,8 @@ class SamHQModel(SamHQPreTrainedModel):
|
|||||||
return SamHQImageSegmentationOutput(
|
return SamHQImageSegmentationOutput(
|
||||||
iou_scores=mask_decoder_output[1],
|
iou_scores=mask_decoder_output[1],
|
||||||
pred_masks=mask_decoder_output[0],
|
pred_masks=mask_decoder_output[0],
|
||||||
vision_hidden_states=vision_outputs.hidden_states,
|
vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None,
|
||||||
vision_attentions=vision_outputs.attentions,
|
vision_attentions=vision_outputs.attentions if pixel_values is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -327,7 +327,7 @@ class SamHQMaskDecoder(nn.Module):
|
|||||||
- (Optional) A tuple containing attention tensors if output_attentions is True.
|
- (Optional) A tuple containing attention tensors if output_attentions is True.
|
||||||
"""
|
"""
|
||||||
batch_size, num_channels, height, width = image_embeddings.shape
|
batch_size, num_channels, height, width = image_embeddings.shape
|
||||||
point_batch_size = sparse_prompt_embeddings.shape[1]
|
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
|
||||||
|
|
||||||
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
|
has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class SamHQMaskDecoder(nn.Module):
|
|||||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
|
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0)
|
||||||
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
||||||
|
|
||||||
if torch.any(sparse_prompt_embeddings != 0):
|
if sparse_prompt_embeddings is not None:
|
||||||
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
|
tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2)
|
||||||
else:
|
else:
|
||||||
tokens = output_tokens
|
tokens = output_tokens
|
||||||
@@ -641,8 +641,8 @@ class SamHQModel(SamModel):
|
|||||||
return SamHQImageSegmentationOutput(
|
return SamHQImageSegmentationOutput(
|
||||||
iou_scores=mask_decoder_output[1],
|
iou_scores=mask_decoder_output[1],
|
||||||
pred_masks=mask_decoder_output[0],
|
pred_masks=mask_decoder_output[0],
|
||||||
vision_hidden_states=vision_outputs.hidden_states,
|
vision_hidden_states=vision_outputs.hidden_states if pixel_values is not None else None,
|
||||||
vision_attentions=vision_outputs.attentions,
|
vision_attentions=vision_outputs.attentions if pixel_values is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -806,7 +806,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
expectations = Expectations(
|
expectations = Expectations(
|
||||||
{
|
{
|
||||||
(None, None): [-13.1695, -14.6201, -14.8989],
|
(None, None): [-13.1695, -14.6201, -14.8989],
|
||||||
("cuda", 8): [-13.1668, -14.6182, -14.8970],
|
("cuda", 8): [-7.6769, -9.6935, -9.8773],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
|
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
|
||||||
@@ -831,9 +831,9 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
scores = outputs.iou_scores.squeeze()
|
scores = outputs.iou_scores.squeeze()
|
||||||
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
masks = outputs.pred_masks[0, 0, 0, 0, :3]
|
||||||
self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4))
|
torch.testing.assert_close(scores[-1], torch.tensor(0.9700).to(torch_device), atol=2e-4, rtol=2e-4)
|
||||||
self.assertTrue(
|
torch.testing.assert_close(
|
||||||
torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2)
|
masks, torch.tensor([-9.2033, -8.5505, -7.1361]).to(torch_device), atol=3e-2, rtol=3e-2
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_mask_generation_batched_points_batched_images(self):
|
def test_inference_mask_generation_batched_points_batched_images(self):
|
||||||
@@ -895,7 +895,7 @@ class SamHQModelIntegrationTest(unittest.TestCase):
|
|||||||
expectations = Expectations(
|
expectations = Expectations(
|
||||||
{
|
{
|
||||||
(None, None): [-40.2445, -37.4300, -38.1577],
|
(None, None): [-40.2445, -37.4300, -38.1577],
|
||||||
("cuda", 8): [-40.2351, -37.4334, -38.1526],
|
("cuda", 8): [-14.1195, -17.2663, -13.7805],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
|
EXPECTED_MASKS = torch.tensor(expectations.get_expectation()).to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user