diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 43d88232e3..3b8e1aba71 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module): # Expand per-image data in batch direction to be per-point image_embeddings = image_embeddings + dense_prompt_embeddings - image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1) - image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1) + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) # Run the transformer, image_positional_embedding are consumed point_embedding, image_embeddings, attentions = self.transformer( diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py index a47b091a09..48b25ae134 100644 --- a/src/transformers/models/sam/modeling_tf_sam.py +++ b/src/transformers/models/sam/modeling_tf_sam.py @@ -517,8 +517,8 @@ class TFSamMaskDecoder(tf.keras.layers.Layer): point_embeddings = tf.cast(tokens, self.iou_token.dtype) image_embeddings = image_embeddings + dense_prompt_embeddings - image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) - image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1]) + image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0) + image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0) point_embedding, image_embeddings, attentions = self.transformer( point_embeddings=point_embeddings,