diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 393f20cfc8..bf14a4b241 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -223,9 +223,7 @@ class SamAttention(nn.Module): def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: batch, n_heads, n_tokens, c_per_head = hidden_states.shape hidden_states = hidden_states.transpose(1, 2) - return hidden_states.reshape( - batch // max(1, point_batch_size), point_batch_size, n_tokens, n_heads * c_per_head - ) + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: # Input projections @@ -482,7 +480,7 @@ class SamMaskDecoder(nn.Module): Whether or not to return the attentions tensors of all attention layers. """ batch_size, num_channels, height, width = image_embeddings.shape - point_batch_size = max(1, sparse_prompt_embeddings.shape[1]) + point_batch_size = sparse_prompt_embeddings.shape[1] # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) @@ -634,8 +632,18 @@ class SamPromptEncoder(nn.Module): torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), ) - point_embedding[labels == 0] += self.point_embed[0].weight - point_embedding[labels == 1] += self.point_embed[1].weight + point_embedding = torch.where( + (labels == 0)[:, :, :, None], + point_embedding + self.point_embed[0].weight[None, None, :, :], + point_embedding, + ) + + point_embedding = torch.where( + (labels == 1)[:, :, :, None], + point_embedding + self.point_embed[1].weight[None, None, :, :], + point_embedding, + ) + return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: @@ -675,8 +683,7 @@ class SamPromptEncoder(nn.Module): if input_labels is None: raise ValueError("If points are provided, labels must also be provided.") point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) - sparse_embeddings = torch.empty((batch_size, point_batch_size, 0, self.hidden_size), device=target_device) - sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=2) + sparse_embeddings = point_embeddings if input_boxes is not None: batch_size = input_boxes.shape[0] box_embeddings = self._embed_boxes(input_boxes) @@ -692,7 +699,7 @@ class SamPromptEncoder(nn.Module): ) if sparse_embeddings is None: - sparse_embeddings = torch.empty((batch_size, 0, 1, self.hidden_size), device=target_device) + sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device) return sparse_embeddings, dense_embeddings @@ -742,17 +749,13 @@ class SamVisionAttention(nn.Module): Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. - rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), - size=max_rel_dist, - mode="linear", - ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - else: - rel_pos_resized = rel_pos + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) @@ -865,8 +868,7 @@ class SamVisionLayer(nn.Module): pad_h = (window_size - height % window_size) % window_size pad_w = (window_size - width % window_size) % window_size - if pad_h > 0 or pad_w > 0: - hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) pad_height, pad_width = height + pad_h, width + pad_w hidden_states = hidden_states.reshape( @@ -902,8 +904,7 @@ class SamVisionLayer(nn.Module): hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) ) - if pad_height > height or pad_width > width: - hidden_states = hidden_states[:, :height, :width, :].contiguous() + hidden_states = hidden_states[:, :height, :width, :].contiguous() return hidden_states def forward(