Add dynamic resolution input/interpolate position embedding to SigLIP (#30719)
* Add interpolate positional encoding to siglip * Change # of patches for siglip interpolation test * fix formatting * Apply nit suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -265,11 +265,53 @@ class SiglipVisionEmbeddings(nn.Module):
|
|||||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs)
|
||||||
|
that allows the model to interpolate the pre-trained position encodings such that it can be usable on
|
||||||
|
higher resolution images.
|
||||||
|
|
||||||
|
Source:
|
||||||
|
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||||
|
"""
|
||||||
|
position_embeddings = self.position_embedding.weight.unsqueeze(0)
|
||||||
|
num_patches = embeddings.shape[1]
|
||||||
|
num_positions = position_embeddings.shape[1]
|
||||||
|
if num_patches == num_positions and height == width:
|
||||||
|
return position_embeddings
|
||||||
|
|
||||||
|
dim = embeddings.shape[-1]
|
||||||
|
height = height // self.patch_size
|
||||||
|
width = width // self.patch_size
|
||||||
|
# we add a small number to avoid floating point error in the interpolation
|
||||||
|
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||||
|
height, width = height + 0.1, width + 0.1
|
||||||
|
|
||||||
|
patch_pos_embed = position_embeddings.reshape(
|
||||||
|
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
||||||
|
)
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||||
|
patch_pos_embed = nn.functional.interpolate(
|
||||||
|
patch_pos_embed,
|
||||||
|
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
|
||||||
|
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||||
|
|
||||||
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||||
|
return patch_pos_embed
|
||||||
|
|
||||||
|
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
||||||
|
_, _, height, width = pixel_values.shape
|
||||||
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
||||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
if interpolate_pos_encoding:
|
||||||
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
|
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@@ -564,6 +606,8 @@ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -601,6 +645,8 @@ SIGLIP_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -848,6 +894,7 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -859,7 +906,7 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
hidden_states = self.embeddings(pixel_values)
|
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
@@ -935,6 +982,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -965,6 +1013,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1055,6 +1104,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1092,6 +1142,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
pooled_output = vision_outputs[1]
|
pooled_output = vision_outputs[1]
|
||||||
@@ -1110,6 +1161,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple, SiglipOutput]:
|
) -> Union[Tuple, SiglipOutput]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1152,6 +1204,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_outputs = self.text_model(
|
text_outputs = self.text_model(
|
||||||
@@ -1226,6 +1279,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[tuple, ImageClassifierOutput]:
|
) -> Union[tuple, ImageClassifierOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
@@ -1271,6 +1325,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|||||||
@@ -687,3 +687,25 @@ class SiglipModelIntegrationTest(unittest.TestCase):
|
|||||||
probs = torch.sigmoid(logits_per_image) # these are the probabilities
|
probs = torch.sigmoid(logits_per_image) # these are the probabilities
|
||||||
expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
|
expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
|
||||||
self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))
|
self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
model_name = "google/siglip-base-patch16-224"
|
||||||
|
model = SiglipModel.from_pretrained(model_name).to(torch_device)
|
||||||
|
|
||||||
|
# 640 x 480 image
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})
|
||||||
|
|
||||||
|
inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||||
|
|
||||||
|
# verify the shape
|
||||||
|
# patch size = 16
|
||||||
|
# batch size 1, (640/16) * (480/16) = 1200 patches, 768 hidden size
|
||||||
|
expected_shape = torch.Size((1, 1200, 768))
|
||||||
|
|
||||||
|
self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user