Add dynamic resolution input/interpolate position embedding to deit (#31131)

* Added interpolate pos encoding feature and test to deit

* Added interpolate pos encoding feature and test for deit TF model

* readded accidentally delted test for multi_gpu

* storing only patch_size instead of entire config and removed commented code

* Update modeling_tf_deit.py to remove extra line

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Kristen Pereira
2024-06-04 05:29:01 -04:00
committed by GitHub
parent d64e4da713
commit de460e28e1
4 changed files with 161 additions and 14 deletions

View File

@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to(
torch_device
)
image_processor = self.default_image_processor
# image size is {"height": 480, "width": 640}
image = prepare_img()
image_processor.size = {"height": 480, "width": 640}
# center crop set to False so image is not center cropped to 224x224
inputs = image_processor(images=image, return_tensors="pt", do_center_crop=False).to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
@slow
@require_accelerate
@require_torch_accelerator