diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index f87cd77df6..d3caf9ac9c 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -623,10 +623,16 @@ class SamPromptEncoder(nn.Module): input_shape = (self.input_image_size, self.input_image_size) point_embedding = self.shared_embedding(points, input_shape) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) - point_embedding[labels == -10] = 0.0 # ignore points + # This is required for the ONNX export. The dtype, device need to be explicitely + # specificed as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), + ) point_embedding[labels == 0] += self.point_embed[0].weight point_embedding[labels == 1] += self.point_embed[1].weight