Fix: repeat per sample for SAM image embeddings (#25074)
Repeat per sample for SAM image embeddings
This commit is contained in:
@@ -507,8 +507,8 @@ class SamMaskDecoder(nn.Module):
|
||||
|
||||
# Expand per-image data in batch direction to be per-point
|
||||
image_embeddings = image_embeddings + dense_prompt_embeddings
|
||||
image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1)
|
||||
image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1)
|
||||
image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
|
||||
image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
|
||||
|
||||
# Run the transformer, image_positional_embedding are consumed
|
||||
point_embedding, image_embeddings, attentions = self.transformer(
|
||||
|
||||
Reference in New Issue
Block a user