Add dynamic resolution input/interpolate position embedding to SigLIP (#30719)
* Add interpolate positional encoding to siglip
* Change # of patches for siglip interpolation test
* fix formatting
* Apply nit suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -687,3 +687,25 @@ class SiglipModelIntegrationTest(unittest.TestCase):
|
||||
probs = torch.sigmoid(logits_per_image) # these are the probabilities
|
||||
expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))
|
||||
|
||||
@slow
def test_inference_interpolate_pos_encoding(self):
    """Check the vision output shape when position embeddings are interpolated.

    The processor is configured with ``do_resize=False`` so the fixture
    image is not resized before being fed to the model, and the forward
    pass is run with ``interpolate_pos_encoding=True``. Only the shape of
    the vision model's last hidden state is verified.
    """
    checkpoint = "google/siglip-base-patch16-224"
    siglip = SiglipModel.from_pretrained(checkpoint).to(torch_device)

    # The fixture is a 640 x 480 image.
    picture = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    siglip_processor = SiglipProcessor.from_pretrained(
        checkpoint,
        do_resize=False,
        size={"height": 480, "width": 640},
    )

    encoded = siglip_processor(text="what's in the image", images=picture, return_tensors="pt")
    encoded = encoded.to(torch_device)

    # Inference only: no gradients are needed for a shape check.
    with torch.no_grad():
        outputs = siglip(**encoded, interpolate_pos_encoding=True)

    # With a patch size of 16 the image yields (640 / 16) * (480 / 16) = 1200
    # patches; the expected tensor is (batch 1, 1200 patches, hidden size 768).
    expected_shape = torch.Size([1, 1200, 768])
    vision_hidden_states = outputs.vision_model_output.last_hidden_state

    self.assertEqual(vision_hidden_states.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user