[ONNX] Sam fix (#23110)
* [WIP] Fix for the ONNX export * Apply changes * Remove commented code * Resolve todo * empty -> zeros * fix slow tests --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
@@ -223,9 +223,7 @@ class SamAttention(nn.Module):
|
||||
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
||||
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
return hidden_states.reshape(
|
||||
batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head
|
||||
)
|
||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||
|
||||
def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
|
||||
# Input projections
|
||||
@@ -482,7 +480,7 @@ class SamMaskDecoder(nn.Module):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
"""
|
||||
batch_size, num_channels, height, width = image_embeddings.shape
|
||||
point_batch_size = max(1, sparse_prompt_embeddings.shape[1])
|
||||
point_batch_size = sparse_prompt_embeddings.shape[1]
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
|
||||
@@ -634,8 +632,18 @@ class SamPromptEncoder(nn.Module):
|
||||
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
||||
)
|
||||
|
||||
point_embedding[labels == 0] += self.point_embed[0].weight
|
||||
point_embedding[labels == 1] += self.point_embed[1].weight
|
||||
point_embedding = torch.where(
|
||||
(labels == 0)[:, :, :, None],
|
||||
point_embedding + self.point_embed[0].weight[None, None, :, :],
|
||||
point_embedding,
|
||||
)
|
||||
|
||||
point_embedding = torch.where(
|
||||
(labels == 1)[:, :, :, None],
|
||||
point_embedding + self.point_embed[1].weight[None, None, :, :],
|
||||
point_embedding,
|
||||
)
|
||||
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
@@ -675,8 +683,7 @@ class SamPromptEncoder(nn.Module):
|
||||
if input_labels is None:
|
||||
raise ValueError("If points are provided, labels must also be provided.")
|
||||
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
|
||||
sparse_embeddings = torch.empty((batch_size, point_batch_size, 0, self.hidden_size), device=target_device)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=2)
|
||||
sparse_embeddings = point_embeddings
|
||||
if input_boxes is not None:
|
||||
batch_size = input_boxes.shape[0]
|
||||
box_embeddings = self._embed_boxes(input_boxes)
|
||||
@@ -692,7 +699,7 @@ class SamPromptEncoder(nn.Module):
|
||||
)
|
||||
|
||||
if sparse_embeddings is None:
|
||||
sparse_embeddings = torch.empty((batch_size, 0, 1, self.hidden_size), device=target_device)
|
||||
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
@@ -742,17 +749,13 @@ class SamVisionAttention(nn.Module):
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
@@ -865,8 +868,7 @@ class SamVisionLayer(nn.Module):
|
||||
|
||||
pad_h = (window_size - height % window_size) % window_size
|
||||
pad_w = (window_size - width % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
|
||||
pad_height, pad_width = height + pad_h, width + pad_w
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
@@ -902,8 +904,7 @@ class SamVisionLayer(nn.Module):
|
||||
hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
|
||||
)
|
||||
|
||||
if pad_height > height or pad_width > width:
|
||||
hidden_states = hidden_states[:, :height, :width, :].contiguous()
|
||||
hidden_states = hidden_states[:, :height, :width, :].contiguous()
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user