Make sam ONNX exportable (#22915)
* fix code not exportable * fix * Update src/transformers/models/sam/modeling_sam.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -623,10 +623,16 @@ class SamPromptEncoder(nn.Module):
|
||||
input_shape = (self.input_image_size, self.input_image_size)
|
||||
point_embedding = self.shared_embedding(points, input_shape)
|
||||
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
# torch.where and expanding the labels tensor is required by the ONNX export
|
||||
point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
|
||||
|
||||
point_embedding[labels == -10] = 0.0 # ignore points
|
||||
# This is required for the ONNX export. The dtype, device need to be explicitely
|
||||
# specificed as otherwise torch.onnx.export interprets as double
|
||||
point_embedding = torch.where(
|
||||
labels[..., None] != -10,
|
||||
point_embedding,
|
||||
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
|
||||
)
|
||||
|
||||
point_embedding[labels == 0] += self.point_embed[0].weight
|
||||
point_embedding[labels == 1] += self.point_embed[1].weight
|
||||
|
||||
Reference in New Issue
Block a user