owlvit/2 dynamic input resolution (#34764)
* owlvit/2 dynamic input resolution. * adapt box grid to patch_dim_h patch_dim_w * fix ci * clarify variable naming * clarify variable naming.. * compute box_bias dynamically inside box_predictor * change style part of code * [run-slow] owlvit, owlv2
This commit is contained in:
@@ -821,6 +821,144 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||
expected_logits = torch.tensor([[3.4613, 0.9403]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "google/owlvit-base-patch32"
|
||||
model = OwlViTModel.from_pretrained(model_name).to(torch_device)
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
processor.image_processor.size = {"height": 800, "width": 800}
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
self.assertEqual(
|
||||
outputs.logits_per_image.shape,
|
||||
torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs.logits_per_text.shape,
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
expected_logits = torch.tensor([[3.6278, 0.8861]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
expected_shape = torch.Size((1, 626, 768))
|
||||
self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)
|
||||
|
||||
# OwlViTForObjectDetection part.
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0680, 0.0422, 0.1347], [0.2071, 0.0450, 0.4146], [0.2000, 0.0418, 0.3476]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int((inputs.pixel_values.shape[-1] / model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
# Deactivate interpolate_pos_encoding on same model, and use default image size.
|
||||
# Verify the dynamic change caused by the activation/deactivation of interpolate_pos_encoding of variables: (self.sqrt_num_patch_h, self.sqrt_num_patch_w), self.box_bias from (OwlViTForObjectDetection).
|
||||
processor = OwlViTProcessor.from_pretrained(model_name)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=False)
|
||||
|
||||
num_queries = int((inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size) ** 2)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
expected_default_box_bias = torch.tensor(
|
||||
[
|
||||
[-3.1332, -3.1332, -3.1332, -3.1332],
|
||||
[-2.3968, -3.1332, -3.1332, -3.1332],
|
||||
[-1.9452, -3.1332, -3.1332, -3.1332],
|
||||
]
|
||||
)
|
||||
self.assertTrue(torch.allclose(model.box_bias[:3, :4], expected_default_box_bias, atol=1e-4))
|
||||
|
||||
# Interpolate with any resolution size.
|
||||
processor.image_processor.size = {"height": 1264, "width": 1024}
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(
|
||||
text=[["a photo of a cat", "a photo of a dog"]],
|
||||
images=image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
num_queries = int(
|
||||
(inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size)
|
||||
* (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size)
|
||||
)
|
||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.0499, 0.0301, 0.0983], [0.2244, 0.0365, 0.4663], [0.1387, 0.0314, 0.1859]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||
|
||||
query_image = prepare_img()
|
||||
inputs = processor(
|
||||
images=image,
|
||||
query_images=query_image,
|
||||
max_length=16,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.image_guided_detection(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# No need to check the logits, we just check inference runs fine.
|
||||
num_queries = int(
|
||||
(inputs.pixel_values.shape[-2] // model.config.vision_config.patch_size)
|
||||
* (inputs.pixel_values.shape[-1] // model.config.vision_config.patch_size)
|
||||
)
|
||||
self.assertEqual(outputs.target_pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||
|
||||
@slow
|
||||
def test_inference_object_detection(self):
|
||||
model_name = "google/owlvit-base-patch32"
|
||||
|
||||
Reference in New Issue
Block a user