Enable dynamic resolution input for Beit (#31053)

* Initial attempt

* Updates: PR suggestions

* Interpolate the relative position bias when interpolate_pos_encoding is True

* Add slow tag for the added tests

* Add in DATA2VEC_VISION_INPUTS_DOCSTRING
This commit is contained in:
Omar Salman
2024-06-06 18:47:41 +05:00
committed by GitHub
parent 99895ae5e2
commit 681183974a
4 changed files with 260 additions and 20 deletions

View File

@@ -545,6 +545,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((160, 160))
self.assertEqual(segmentation[0].shape, expected_shape)
@slow
def test_inference_interpolate_pos_encoding(self):
model_name = "microsoft/beit-base-patch16-224-pt22k"
model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device)
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
processor = BeitImageProcessor.from_pretrained(model_name)
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
pixel_values = inputs.pixel_values.to(torch_device)
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
# images than what the model supports.
self.assertFalse(processor.do_center_crop)
with torch.no_grad():
with self.assertRaises(ValueError, msg="doesn't match model"):
model(pixel_values, interpolate_pos_encoding=False)
# with interpolate_pos_encoding being True the model should process the higher resolution image
# successfully and produce the expected output.
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
expected_shape = torch.Size((1, 1801, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
@require_torch
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):