added interpolation for vitmae model in pytorch as well as tf. (#30732)
* added interpolation for vitmae model in pytorch as well as tf. * Update modeling_vit_mae.py irregular import fixed * small changes and proper formatting * changes suggested in review. * modified decoder interpolate_func * arguments and docstring fix * Apply suggestions from code review doc fixes Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -426,7 +426,7 @@ def prepare_img():
|
||||
class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
|
||||
|
||||
@slow
|
||||
def test_inference_for_pretraining(self):
|
||||
@@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
@slow
def test_inference_interpolate_pos_encoding(self):
    # Passing `interpolate_pos_encoding=True` to the forward call lets the
    # pre-trained position embeddings be interpolated, so the checkpoint can
    # be run on an input that was not resized by the image processor.

    # Seed numpy so the masking noise drawn below is reproducible and matches
    # the noise drawn by the PT version of this test.
    np.random.seed(2)

    mae_model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

    # `do_resize=False` keeps the fixture image at its own resolution.
    processor = self.default_image_processor
    raw_image = prepare_img()
    model_inputs = processor(images=raw_image, do_resize=False, return_tensors="tf")

    # Draw one noise value per image patch.
    # NOTE(review): the patch size is read from a default `ViTMAEConfig()`,
    # not from the loaded checkpoint - assumes both agree; confirm.
    default_config = ViTMAEConfig()
    patches_high = raw_image.height // default_config.patch_size
    patches_wide = raw_image.width // default_config.patch_size
    noise = np.random.uniform(size=(1, patches_high * patches_wide))

    # Forward pass with position-embedding interpolation switched on.
    outputs = mae_model(**model_inputs, noise=noise, interpolate_pos_encoding=True)

    # Only the logits shape is checked here, not the values.
    # (1200 is presumably the patch count of the fixture image - verify
    # against `prepare_img`.)
    self.assertEqual(outputs.logits.shape, tf.convert_to_tensor([1, 1200, 768]))
|
||||
|
||||
@@ -296,7 +296,7 @@ def prepare_img():
|
||||
class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
|
||||
|
||||
@slow
|
||||
def test_inference_for_pretraining(self):
|
||||
@@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
|
||||
|
||||
@slow
def test_inference_interpolate_pos_encoding(self):
    # Passing `interpolate_pos_encoding=True` to the forward call lets the
    # pre-trained position embeddings be interpolated, so the checkpoint can
    # be run on an input that was not resized by the image processor.

    # Seed numpy so the masking noise drawn below is reproducible and matches
    # the noise drawn by the TF version of this test.
    np.random.seed(2)

    mae_model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
    mae_model = mae_model.to(torch_device)

    # `do_resize=False` keeps the fixture image at its own resolution.
    processor = self.default_image_processor
    raw_image = prepare_img()
    encoded = processor(images=raw_image, return_tensors="pt", do_resize=False)
    model_inputs = encoded.to(torch_device)

    # Draw one noise value per image patch.
    # NOTE(review): the patch size is read from a default `ViTMAEConfig()`,
    # not from the loaded checkpoint - assumes both agree; confirm.
    default_config = ViTMAEConfig()
    patches_high = raw_image.height // default_config.patch_size
    patches_wide = raw_image.width // default_config.patch_size
    noise = np.random.uniform(size=(1, patches_high * patches_wide))

    # Forward pass without gradient tracking; the numpy noise is converted
    # to a tensor on the test device first.
    with torch.no_grad():
        noise_tensor = torch.from_numpy(noise).to(device=torch_device)
        outputs = mae_model(**model_inputs, noise=noise_tensor, interpolate_pos_encoding=True)

    # Only the logits shape is checked here, not the values.
    # (1200 is presumably the patch count of the fixture image - verify
    # against `prepare_img`.)
    self.assertEqual(outputs.logits.shape, torch.Size((1, 1200, 768)))
|
||||
|
||||
Reference in New Issue
Block a user