From 1dbc1440a7d750f69815a56ebda20d6dc00a37ed Mon Sep 17 00:00:00 2001
From: Xiaoke Huang
Date: Tue, 25 Jul 2023 20:30:14 +0800
Subject: [PATCH] Fix: repeat per sample for SAM image embeddings (#25074)

Repeat per sample for SAM image embeddings
---
 src/transformers/models/sam/modeling_sam.py    | 4 ++--
 src/transformers/models/sam/modeling_tf_sam.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 43d88232e3..3b8e1aba71 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module):
 
         # Expand per-image data in batch direction to be per-point
         image_embeddings = image_embeddings + dense_prompt_embeddings
-        image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1)
-        image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1)
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
 
         # Run the transformer, image_positional_embedding are consumed
         point_embedding, image_embeddings, attentions = self.transformer(
diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index a47b091a09..48b25ae134 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -517,8 +517,8 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
         point_embeddings = tf.cast(tokens, self.iou_token.dtype)
 
         image_embeddings = image_embeddings + dense_prompt_embeddings
-        image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1])
-        image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1])
+        image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
+        image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
 
         point_embedding, image_embeddings, attentions = self.transformer(
             point_embeddings=point_embeddings,