added interpolation for vitmae model in pytorch as well as tf. (#30732)

* added interpolation for vitmae model in pytorch as well as tf.

* Update modeling_vit_mae.py

irregular import fixed

* small changes and proper formatting

* changes suggested in review.

* modified decoder interpolate_func

* arguments and docstring fix

* Apply suggestions from code review

doc fixes

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
BHUVAN M
2024-05-24 20:50:09 +05:30
committed by GitHub
parent a3cdff417b
commit e5103a76cc
4 changed files with 333 additions and 60 deletions

View File

@@ -426,7 +426,7 @@ def prepare_img():
class TFViTMAEModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
@slow
def test_inference_for_pretraining(self):
@@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase):
)
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run TF ViTMAE on an un-resized image with position-embedding interpolation.

    ViTMAE models take an `interpolate_pos_encoding` argument in their forward
    pass, which interpolates the pre-trained position embeddings so the model
    can be used at resolutions other than the one it was trained on. The test
    only checks the shape of the returned logits.
    """
    # Seed NumPy so the noise (and hence the random mask derived from it) is
    # reproducible across the PT and TF versions of this test.
    np.random.seed(2)

    model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

    image_processor = self.default_image_processor
    image = prepare_img()
    # Resizing is disabled so the model receives the image at its own size
    # rather than the processor's default size.
    inputs = image_processor(images=image, do_resize=False, return_tensors="tf")

    # Prepare an explicit noise vector, one value per image patch, mirroring
    # the PT test so both frameworks operate on the same inputs.
    # The patch size is read from the loaded checkpoint's config rather than a
    # default-constructed ViTMAEConfig, so the noise shape always matches the
    # model that is actually being run.
    patch_size = model.config.patch_size
    num_patches = (image.height // patch_size) * (image.width // patch_size)
    noise = np.random.uniform(size=(1, num_patches))

    # forward pass
    outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)

    # verify the logits
    expected_shape = tf.convert_to_tensor([1, 1200, 768])
    self.assertEqual(outputs.logits.shape, expected_shape)

View File

@@ -296,7 +296,7 @@ def prepare_img():
class ViTMAEModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
@slow
def test_inference_for_pretraining(self):
@@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
    """Run PyTorch ViTMAE on an un-resized image with position-embedding interpolation.

    ViTMAE models take an `interpolate_pos_encoding` argument in their forward
    pass, which interpolates the pre-trained position embeddings so the model
    can be used at resolutions other than the one it was trained on. The test
    only checks the shape of the returned logits.
    """
    # Seed NumPy so the noise (and hence the random mask derived from it) is
    # reproducible across the PT and TF versions of this test.
    np.random.seed(2)

    model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)

    image_processor = self.default_image_processor
    image = prepare_img()
    # Resizing is disabled so the model receives the image at its own size
    # rather than the processor's default size.
    inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device)

    # Prepare an explicit noise vector, one value per image patch; the TF test
    # builds the same vector so both frameworks operate on the same inputs.
    # The patch size is read from the loaded checkpoint's config rather than a
    # default-constructed ViTMAEConfig, so the noise shape always matches the
    # model that is actually being run.
    patch_size = model.config.patch_size
    num_patches = (image.height // patch_size) * (image.width // patch_size)
    noise = np.random.uniform(size=(1, num_patches))

    # forward pass (no gradients are needed for a shape check)
    with torch.no_grad():
        outputs = model(
            **inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
        )

    # verify the logits
    expected_shape = torch.Size((1, 1200, 768))
    self.assertEqual(outputs.logits.shape, expected_shape)