Fix: repeat per sample for SAM image embeddings (#25074)

Repeat per sample for SAM image embeddings
This commit is contained in:
Xiaoke Huang
2023-07-25 20:30:14 +08:00
committed by GitHub
parent cb8abee511
commit 1dbc1440a7
2 changed files with 4 additions and 4 deletions

View File

@@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module):
# Expand per-image data in batch direction to be per-point
image_embeddings = image_embeddings + dense_prompt_embeddings
image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1)
image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1)
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
# Run the transformer, image_positional_embedding are consumed
point_embedding, image_embeddings, attentions = self.transformer(