Add dynamic resolution input/interpolate position embedding to deit (#31131)
* Added interpolate pos encoding feature and test to deit * Added interpolate pos encoding feature and test for deit TF model * readded accidentally delted test for multi_gpu * storing only patch_size instead of entire config and removed commented code * Update modeling_tf_deit.py to remove extra line Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
|
||||
# image size is {"height": 480, "width": 640}
|
||||
image = prepare_img()
|
||||
image_processor.size = {"height": 480, "width": 640}
|
||||
# center crop set to False so image is not center cropped to 224x224
|
||||
inputs = image_processor(images=image, return_tensors="pt", do_center_crop=False).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -293,3 +293,20 @@ class DeiTModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
|
||||
|
||||
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
# image size is {"height": 480, "width": 640}
|
||||
image = prepare_img()
|
||||
image_processor.size = {"height": 480, "width": 640}
|
||||
# center crop set to False so image is not center cropped to 224x224
|
||||
inputs = image_processor(images=image, return_tensors="tf", do_center_crop=False)
|
||||
# forward pass
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = tf.TensorShape((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user