Support for SDPA for SAM models (#34110)

* feat: add support for sdpa and gradient checkpointing

* fix: ruff format

* fix: config sdpa

* fix: sdpa layer naming convention

* fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states

* test: skip incompatible tests and fix loading issue with sdpa

- Updated tests to skip cases flash and dynamic compile.
- Minor adjustment to ensure correct loading of model with sdpa for dispatch test.

* style: apply Ruff formatting

* ruff fix again after rebase

* [run-slow] sam

* [run-slow] sam

* refactor: Address review comments and improve sub-config handling in SAM model tests

- Added attributes for sub_configs as per PR #34410.
- Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config.
- Added class attribute _is_composite=True to the tester class
- test_sdpa_can_dispatch_composite_models added

* [run-slow] sam

* style: ruff

* [run-slow] sam

* style: ruff again ...

* [run-slow] sam
This commit is contained in:
Magnus
2024-12-17 14:46:05 +01:00
committed by GitHub
parent 747f361da1
commit 6eb00dd2f0
4 changed files with 256 additions and 31 deletions

View File

@@ -246,6 +246,47 @@ class SamAttention(nn.Module):
return out
class SamSdpaAttention(SamAttention):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values. Using SDPA instead of the default attention.
"""
def __init__(self, config, downsample_rate=None):
super().__init__(config, downsample_rate)
def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)
# Scaled dot product attention
attn_mask = None
if attention_similarity is not None:
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
# Get output
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
SAM_ATTENTION_CLASSES = {
"eager": SamAttention,
"sdpa": SamSdpaAttention,
}
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
@@ -266,18 +307,21 @@ class SamTwoWayAttentionBlock(nn.Module):
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.self_attn = SamAttention(config, downsample_rate=1)
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
@@ -344,7 +388,7 @@ class SamTwoWayTransformer(nn.Module):
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
self.final_attn_token_to_image = SamAttention(config)
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward(
@@ -431,7 +475,7 @@ class SamFeedForward(nn.Module):
class SamMaskDecoder(nn.Module):
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_multimask_outputs = config.num_multimask_outputs
@@ -856,11 +900,118 @@ class SamVisionAttention(nn.Module):
return outputs
class SamVisionSdpaAttention(SamVisionAttention):
"""
Multi-head Attention block with relative position embeddings.
Using SDPA instead of the default attention.
"""
def __init__(self, config, window_size):
super().__init__(config, window_size)
def add_decomposed_rel_pos(
self,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
This method is reimplemented to follow the implementation in:
https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
This implementation is more memory efficient when using SDPA in the forward method.
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, dim = query.shape
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
return rel_h, rel_w
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
.permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = self.add_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
query = query.view(batch_size, self.num_attention_heads, height * width, -1)
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
value = value.view(batch_size, self.num_attention_heads, height * width, -1)
if self.use_rel_pos:
rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(
batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
attn_output = (
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
.permute(0, 2, 3, 1, 4)
.reshape(batch_size, height, width, -1)
)
attn_output = self.proj(attn_output)
if output_attentions:
# For output_attentions, calculate the attention weights
attn_weights = (query @ key.transpose(-2, -1)) * self.scale
if attn_bias is not None:
attn_weights = attn_weights + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
SAM_VISION_ATTENTION_CLASSES = {
"eager": SamVisionAttention,
"sdpa": SamVisionSdpaAttention,
}
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.window_size = window_size
@@ -1071,6 +1222,8 @@ class SamPreTrainedModel(PreTrainedModel):
base_model_prefix = "sam"
main_input_name = "pixel_values"
_no_split_modules = ["SamVisionAttention"]
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range