Fix: repeat per sample for SAM image embeddings (#25074)

Repeat per sample for SAM image embeddings
This commit is contained in:
Xiaoke Huang
2023-07-25 20:30:14 +08:00
committed by GitHub
parent cb8abee511
commit 1dbc1440a7
2 changed files with 4 additions and 4 deletions

View File

@@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module):
# Expand per-image data in batch direction to be per-point # Expand per-image data in batch direction to be per-point
image_embeddings = image_embeddings + dense_prompt_embeddings image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1) image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1) image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
# Run the transformer, image_positional_embedding are consumed # Run the transformer, image_positional_embedding are consumed
point_embedding, image_embeddings, attentions = self.transformer( point_embedding, image_embeddings, attentions = self.transformer(

View File

@@ -517,8 +517,8 @@ class TFSamMaskDecoder(tf.keras.layers.Layer):
point_embeddings = tf.cast(tokens, self.iou_token.dtype) point_embeddings = tf.cast(tokens, self.iou_token.dtype)
image_embeddings = image_embeddings + dense_prompt_embeddings image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1]) image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
point_embedding, image_embeddings, attentions = self.transformer( point_embedding, image_embeddings, attentions = self.transformer(
point_embeddings=point_embeddings, point_embeddings=point_embeddings,