owlvit/2 dynamic input resolution (#34764)

* owlvit/2 dynamic input resolution.

* adapt box grid to patch_dim_h patch_dim_w

* fix ci

* clarify variable naming

* clarify variable naming..

* compute box_bias dynamically inside box_predictor

* change style part of code

* [run-slow] owlvit, owlv2
This commit is contained in:
bastrob
2024-12-21 09:51:09 +01:00
committed by GitHub
parent 608e163b52
commit 8f38f58f3d
4 changed files with 565 additions and 73 deletions

View File

@@ -828,6 +828,144 @@ class Owlv2ModelIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([[-6.2229, -8.2601]], device=torch_device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run OWLv2 on non-default input resolutions with `interpolate_pos_encoding=True`.

    The test walks through five stages:

    1. `Owlv2Model` on a 1024x1024 input: logits shape/values and vision hidden-state shape.
    2. `Owlv2ForObjectDetection` on the same input: predicted box shape and values.
    3. `image_guided_detection` on the same resolution: output shape only.
    4. The same detection model with interpolation switched off and the default
       processor size: box shape and the model's `box_bias` attribute.
    5. A non-square 1264x1024 input: text-guided and image-guided detection.
    """
    model_name = "google/owlv2-base-patch16"

    # --- 1. Owlv2Model at 1024x1024 -----------------------------------------
    model = Owlv2Model.from_pretrained(model_name).to(torch_device)
    processor = OwlViTProcessor.from_pretrained(model_name)
    processor.image_processor.size = {"height": 1024, "width": 1024}

    image = prepare_img()
    inputs = processor(
        text=[["a photo of a cat", "a photo of a dog"]],
        images=image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=True)

    # verify the logits
    self.assertEqual(
        outputs.logits_per_image.shape,
        torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
    )
    self.assertEqual(
        outputs.logits_per_text.shape,
        torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
    )
    expected_logits = torch.tensor([[-6.2520, -8.2970]], device=torch_device)
    self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))

    # 4097 = 64 * 64 + 1; presumably a 64x64 patch grid plus one class token.
    expected_shape = torch.Size((1, 4097, 768))
    self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

    # --- 2. Owlv2ForObjectDetection at 1024x1024 ----------------------------
    model = Owlv2ForObjectDetection.from_pretrained(model_name).to(torch_device)
    patch_size = model.config.vision_config.patch_size

    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=True)

    # Floor division keeps the count an exact integer, matching the later stages.
    num_queries = (inputs.pixel_values.shape[-1] // patch_size) ** 2
    self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
    expected_slice_boxes = torch.tensor(
        [[0.2407, 0.0553, 0.4636], [0.1082, 0.0494, 0.1861], [0.2459, 0.0527, 0.4398]]
    ).to(torch_device)
    self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

    # --- 3. Image-guided detection at 1024x1024 -----------------------------
    model = Owlv2ForObjectDetection.from_pretrained(model_name).to(torch_device)
    query_image = prepare_img()
    inputs = processor(
        images=image,
        query_images=query_image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)

    # No need to check the logits, we just check inference runs fine.
    num_queries = (inputs.pixel_values.shape[-1] // patch_size) ** 2
    self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))

    # --- 4. Same detection model, interpolation off, default image size -----
    # A fresh processor restores the checkpoint's default size; the model is
    # reused on purpose so a run with interpolation is followed by one without.
    processor = OwlViTProcessor.from_pretrained(model_name)
    image = prepare_img()
    inputs = processor(
        text=[["a photo of a cat", "a photo of a dog"]],
        images=image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=False)

    num_queries = (inputs.pixel_values.shape[-1] // patch_size) ** 2
    self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))

    expected_default_box_bias = torch.tensor(
        [
            [-4.0717, -4.0717, -4.0717, -4.0717],
            [-3.3644, -4.0717, -4.0717, -4.0717],
            [-2.9425, -4.0717, -4.0717, -4.0717],
        ]
    )
    # Compare on CPU so the check does not depend on which device `box_bias` lives on.
    self.assertTrue(torch.allclose(model.box_bias[:3, :4].cpu(), expected_default_box_bias, atol=1e-4))

    # --- 5. Non-square resolution (1264x1024) -------------------------------
    processor.image_processor.size = {"height": 1264, "width": 1024}
    image = prepare_img()
    inputs = processor(
        text=[["a photo of a cat", "a photo of a dog"]],
        images=image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=True)

    # Height and width now differ, so the grid is (H // p) x (W // p).
    num_queries = (inputs.pixel_values.shape[-2] // patch_size) * (inputs.pixel_values.shape[-1] // patch_size)
    self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
    expected_slice_boxes = torch.tensor(
        [[0.2438, 0.0945, 0.4675], [0.1361, 0.0431, 0.2406], [0.2465, 0.0428, 0.4429]]
    ).to(torch_device)
    self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

    query_image = prepare_img()
    inputs = processor(
        images=image,
        query_images=query_image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)

    # No need to check the logits, we just check inference runs fine.
    num_queries = (inputs.pixel_values.shape[-2] // patch_size) * (inputs.pixel_values.shape[-1] // patch_size)
    self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
@slow
def test_inference_object_detection(self):
model_name = "google/owlv2-base-patch16"

View File

@@ -821,6 +821,144 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run OWL-ViT on non-default input resolutions with `interpolate_pos_encoding=True`.

    The test walks through five stages:

    1. `OwlViTModel` on an 800x800 input: logits shape/values and vision hidden-state shape.
    2. `OwlViTForObjectDetection` on the same input: predicted box shape and values.
    3. `image_guided_detection` on the same resolution: output shape only.
    4. The same detection model with interpolation switched off and the default
       processor size: box shape and the model's `box_bias` attribute.
    5. A non-square 1264x1024 input: text-guided and image-guided detection.
    """
    model_name = "google/owlvit-base-patch32"

    # --- 1. OwlViTModel at 800x800 ------------------------------------------
    model = OwlViTModel.from_pretrained(model_name).to(torch_device)
    processor = OwlViTProcessor.from_pretrained(model_name)
    processor.image_processor.size = {"height": 800, "width": 800}

    image = prepare_img()
    inputs = processor(
        text=[["a photo of a cat", "a photo of a dog"]],
        images=image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=True)

    # verify the logits
    self.assertEqual(
        outputs.logits_per_image.shape,
        torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
    )
    self.assertEqual(
        outputs.logits_per_text.shape,
        torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
    )
    expected_logits = torch.tensor([[3.6278, 0.8861]], device=torch_device)
    self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))

    # 626 = 25 * 25 + 1; presumably a 25x25 patch grid plus one class token.
    expected_shape = torch.Size((1, 626, 768))
    self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

    # --- 2. OwlViTForObjectDetection at 800x800 -----------------------------
    model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
    patch_size = model.config.vision_config.patch_size

    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=True)

    num_queries = (inputs.pixel_values.shape[-1] // patch_size) ** 2
    self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
    expected_slice_boxes = torch.tensor(
        [[0.0680, 0.0422, 0.1347], [0.2071, 0.0450, 0.4146], [0.2000, 0.0418, 0.3476]]
    ).to(torch_device)
    self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

    # --- 3. Image-guided detection at 800x800 -------------------------------
    model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
    query_image = prepare_img()
    inputs = processor(
        images=image,
        query_images=query_image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)

    # No need to check the logits, we just check inference runs fine.
    # Floor division keeps the count an exact integer, matching the other stages.
    num_queries = (inputs.pixel_values.shape[-1] // patch_size) ** 2
    self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))

    # --- 4. Same detection model, interpolation off, default image size -----
    # A fresh processor restores the checkpoint's default size; the model is
    # reused on purpose so a run with interpolation is followed by one without.
    processor = OwlViTProcessor.from_pretrained(model_name)
    image = prepare_img()
    inputs = processor(
        text=[["a photo of a cat", "a photo of a dog"]],
        images=image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=False)

    num_queries = (inputs.pixel_values.shape[-1] // patch_size) ** 2
    self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))

    expected_default_box_bias = torch.tensor(
        [
            [-3.1332, -3.1332, -3.1332, -3.1332],
            [-2.3968, -3.1332, -3.1332, -3.1332],
            [-1.9452, -3.1332, -3.1332, -3.1332],
        ]
    )
    # Compare on CPU so the check does not depend on which device `box_bias` lives on.
    self.assertTrue(torch.allclose(model.box_bias[:3, :4].cpu(), expected_default_box_bias, atol=1e-4))

    # --- 5. Non-square resolution (1264x1024) -------------------------------
    processor.image_processor.size = {"height": 1264, "width": 1024}
    image = prepare_img()
    inputs = processor(
        text=[["a photo of a cat", "a photo of a dog"]],
        images=image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model(**inputs, interpolate_pos_encoding=True)

    # Height and width now differ, so the grid is (H // p) x (W // p).
    num_queries = (inputs.pixel_values.shape[-2] // patch_size) * (inputs.pixel_values.shape[-1] // patch_size)
    self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
    expected_slice_boxes = torch.tensor(
        [[0.0499, 0.0301, 0.0983], [0.2244, 0.0365, 0.4663], [0.1387, 0.0314, 0.1859]]
    ).to(torch_device)
    self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))

    query_image = prepare_img()
    inputs = processor(
        images=image,
        query_images=query_image,
        max_length=16,
        padding="max_length",
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)

    # No need to check the logits, we just check inference runs fine.
    num_queries = (inputs.pixel_values.shape[-2] // patch_size) * (inputs.pixel_values.shape[-1] // patch_size)
    self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
@slow
def test_inference_object_detection(self):
model_name = "google/owlvit-base-patch32"