Enable dynamic resolution input for Beit (#31053)
* Initial attempt * Updates: PR suggestions * Interpolate the relative position bias when interpolate_pos_encoding is True * Add slow tag for the added tests * Add in DATA2VEC_VISION_INPUTS_DOCSTRING
This commit is contained in:
@@ -545,6 +545,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size((160, 160))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "microsoft/beit-base-patch16-224-pt22k"
|
||||
model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device)
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
processor = BeitImageProcessor.from_pretrained(model_name)
|
||||
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||
# images than what the model supports.
|
||||
self.assertFalse(processor.do_center_crop)
|
||||
with torch.no_grad():
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
model(pixel_values, interpolate_pos_encoding=False)
|
||||
|
||||
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||
# successfully and produce the expected output.
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
expected_shape = torch.Size((1, 1801, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
||||
@@ -341,3 +341,30 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
|
||||
self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "facebook/data2vec-vision-base-ft1k"
|
||||
model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
processor = BeitImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k")
|
||||
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||
# images than what the model supports.
|
||||
self.assertFalse(processor.do_center_crop)
|
||||
with torch.no_grad():
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
model(pixel_values, interpolate_pos_encoding=False)
|
||||
|
||||
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||
# successfully and produce the expected output.
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
expected_shape = torch.Size((1, 1801, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user