[ONNX] Sam fix (#23110)
* [WIP] Fix for the ONNX export * Apply changes * Remove commented code * Resolve todo * empty -> zeros * fix slow tests --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
@@ -223,9 +223,7 @@ class SamAttention(nn.Module):
|
|||||||
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
||||||
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
|
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
|
||||||
hidden_states = hidden_states.transpose(1, 2)
|
hidden_states = hidden_states.transpose(1, 2)
|
||||||
return hidden_states.reshape(
|
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||||
batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
|
def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
|
||||||
# Input projections
|
# Input projections
|
||||||
@@ -482,7 +480,7 @@ class SamMaskDecoder(nn.Module):
|
|||||||
Whether or not to return the attentions tensors of all attention layers.
|
Whether or not to return the attentions tensors of all attention layers.
|
||||||
"""
|
"""
|
||||||
batch_size, num_channels, height, width = image_embeddings.shape
|
batch_size, num_channels, height, width = image_embeddings.shape
|
||||||
point_batch_size = max(1, sparse_prompt_embeddings.shape[1])
|
point_batch_size = sparse_prompt_embeddings.shape[1]
|
||||||
# Concatenate output tokens
|
# Concatenate output tokens
|
||||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||||
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
||||||
@@ -634,8 +632,18 @@ class SamPromptEncoder(nn.Module):
|
|||||||
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
||||||
)
|
)
|
||||||
|
|
||||||
point_embedding[labels == 0] += self.point_embed[0].weight
|
point_embedding = torch.where(
|
||||||
point_embedding[labels == 1] += self.point_embed[1].weight
|
(labels == 0)[:, :, :, None],
|
||||||
|
point_embedding + self.point_embed[0].weight[None, None, :, :],
|
||||||
|
point_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
point_embedding = torch.where(
|
||||||
|
(labels == 1)[:, :, :, None],
|
||||||
|
point_embedding + self.point_embed[1].weight[None, None, :, :],
|
||||||
|
point_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
return point_embedding
|
return point_embedding
|
||||||
|
|
||||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -675,8 +683,7 @@ class SamPromptEncoder(nn.Module):
|
|||||||
if input_labels is None:
|
if input_labels is None:
|
||||||
raise ValueError("If points are provided, labels must also be provided.")
|
raise ValueError("If points are provided, labels must also be provided.")
|
||||||
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
||||||
sparse_embeddings = torch.empty((batch_size, point_batch_size, 0, self.hidden_size), device=target_device)
|
sparse_embeddings = point_embeddings
|
||||||
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=2)
|
|
||||||
if input_boxes is not None:
|
if input_boxes is not None:
|
||||||
batch_size = input_boxes.shape[0]
|
batch_size = input_boxes.shape[0]
|
||||||
box_embeddings = self._embed_boxes(input_boxes)
|
box_embeddings = self._embed_boxes(input_boxes)
|
||||||
@@ -692,7 +699,7 @@ class SamPromptEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if sparse_embeddings is None:
|
if sparse_embeddings is None:
|
||||||
sparse_embeddings = torch.empty((batch_size, 0, 1, self.hidden_size), device=target_device)
|
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
|
||||||
|
|
||||||
return sparse_embeddings, dense_embeddings
|
return sparse_embeddings, dense_embeddings
|
||||||
|
|
||||||
@@ -742,8 +749,6 @@ class SamVisionAttention(nn.Module):
|
|||||||
Extracted positional embeddings according to relative positions.
|
Extracted positional embeddings according to relative positions.
|
||||||
"""
|
"""
|
||||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||||
# Interpolate rel pos if needed.
|
|
||||||
if rel_pos.shape[0] != max_rel_dist:
|
|
||||||
# Interpolate rel pos.
|
# Interpolate rel pos.
|
||||||
rel_pos_resized = F.interpolate(
|
rel_pos_resized = F.interpolate(
|
||||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||||
@@ -751,8 +756,6 @@ class SamVisionAttention(nn.Module):
|
|||||||
mode="linear",
|
mode="linear",
|
||||||
)
|
)
|
||||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||||
else:
|
|
||||||
rel_pos_resized = rel_pos
|
|
||||||
|
|
||||||
# Scale the coords with short length if shapes for q and k are different.
|
# Scale the coords with short length if shapes for q and k are different.
|
||||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||||
@@ -865,7 +868,6 @@ class SamVisionLayer(nn.Module):
|
|||||||
|
|
||||||
pad_h = (window_size - height % window_size) % window_size
|
pad_h = (window_size - height % window_size) % window_size
|
||||||
pad_w = (window_size - width % window_size) % window_size
|
pad_w = (window_size - width % window_size) % window_size
|
||||||
if pad_h > 0 or pad_w > 0:
|
|
||||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
|
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
|
||||||
pad_height, pad_width = height + pad_h, width + pad_w
|
pad_height, pad_width = height + pad_h, width + pad_w
|
||||||
|
|
||||||
@@ -902,7 +904,6 @@ class SamVisionLayer(nn.Module):
|
|||||||
hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
|
hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if pad_height > height or pad_width > width:
|
|
||||||
hidden_states = hidden_states[:, :height, :width, :].contiguous()
|
hidden_states = hidden_states[:, :height, :width, :].contiguous()
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user