Add dynamic resolution input/interpolate position embedding to SigLIP (#30719)
* Add interpolate positional encoding to siglip * Change # of patches for siglip interpolation test * fix formatting * Apply nit suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -265,11 +265,53 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs)
|
||||
that allows the model to interpolate the pre-trained position encodings such that it can be usable on
|
||||
higher resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
position_embeddings = self.position_embedding.weight.unsqueeze(0)
|
||||
num_patches = embeddings.shape[1]
|
||||
num_positions = position_embeddings.shape[1]
|
||||
if num_patches == num_positions and height == width:
|
||||
return position_embeddings
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
height, width = height + 0.1, width + 0.1
|
||||
|
||||
patch_pos_embed = position_embeddings.reshape(
|
||||
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return patch_pos_embed
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -564,6 +606,8 @@ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@@ -601,6 +645,8 @@ SIGLIP_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@@ -848,6 +894,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -859,7 +906,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
@@ -935,6 +982,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -965,6 +1013,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
|
||||
@@ -1055,6 +1104,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -1092,6 +1142,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1]
|
||||
@@ -1110,6 +1161,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, SiglipOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -1152,6 +1204,7 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@@ -1226,6 +1279,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[tuple, ImageClassifierOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@@ -1271,6 +1325,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -687,3 +687,25 @@ class SiglipModelIntegrationTest(unittest.TestCase):
|
||||
probs = torch.sigmoid(logits_per_image) # these are the probabilities
|
||||
expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "google/siglip-base-patch16-224"
|
||||
model = SiglipModel.from_pretrained(model_name).to(torch_device)
|
||||
|
||||
# 640 x 480 image
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
processor = SiglipProcessor.from_pretrained(model_name, do_resize=False, size={"height": 480, "width": 640})
|
||||
|
||||
inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the shape
|
||||
# patch size = 16
|
||||
# batch size 1, (640/16) * (480/16) = 1200 patches, 768 hidden size
|
||||
expected_shape = torch.Size((1, 1200, 768))
|
||||
|
||||
self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user