added interpolation for vitmae model in pytorch as well as tf. (#30732)
* added interpolation for vitmae model in pytorch as well as tf. * Update modeling_vit_mae.py irreugalr import fixed * small changes and proper formatting * changes suggested in review. * modified decoder interpolate_func * arguments and docstring fix * Apply suggestions from code review doc fixes Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -240,6 +240,38 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
|
||||
with tf.name_scope(self.patch_embeddings.name):
|
||||
self.patch_embeddings.build(None)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
batch_size, seq_len, dim = shape_list(embeddings)
|
||||
num_patches = seq_len - 1
|
||||
|
||||
_, num_positions, _ = shape_list(self.position_embeddings)
|
||||
num_positions -= 1
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
patch_pos_embed = tf.image.resize(
|
||||
images=tf.reshape(
|
||||
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
),
|
||||
size=(h0, w0),
|
||||
method="bicubic",
|
||||
)
|
||||
|
||||
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
|
||||
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
|
||||
|
||||
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
|
||||
"""
|
||||
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
||||
@@ -281,17 +313,23 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
|
||||
|
||||
return sequence_unmasked, mask, ids_restore
|
||||
|
||||
def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
|
||||
def call(
|
||||
self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False
|
||||
) -> tf.Tensor:
|
||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
if interpolate_pos_encoding:
|
||||
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
position_embeddings = self.position_embeddings
|
||||
# add position embeddings w/o cls token
|
||||
embeddings = embeddings + self.position_embeddings[:, 1:, :]
|
||||
embeddings = embeddings + position_embeddings[:, 1:, :]
|
||||
|
||||
# masking: length -> length * config.mask_ratio
|
||||
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
||||
|
||||
# append cls token
|
||||
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
|
||||
cls_token = self.cls_token + position_embeddings[:, :1, :]
|
||||
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
|
||||
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
|
||||
|
||||
@@ -329,7 +367,9 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
|
||||
name="projection",
|
||||
)
|
||||
|
||||
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||
def call(
|
||||
self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
|
||||
) -> tf.Tensor:
|
||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||
if tf.executing_eagerly():
|
||||
if num_channels != self.num_channels:
|
||||
@@ -337,7 +377,7 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
||||
" configuration."
|
||||
)
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
@@ -741,9 +781,13 @@ class TFViTMAEMainLayer(keras.layers.Layer):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
||||
embedding_output, mask, ids_restore = self.embeddings(
|
||||
pixel_values=pixel_values, training=training, noise=noise
|
||||
pixel_values=pixel_values,
|
||||
training=training,
|
||||
noise=noise,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -874,6 +918,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
|
||||
training (`bool`, *optional*, defaults to `False``):
|
||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||
behaviors between training and evaluation).
|
||||
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the position encodings at the encoder and decoder.
|
||||
"""
|
||||
|
||||
|
||||
@@ -902,6 +949,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -931,6 +979,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
return outputs
|
||||
@@ -1004,6 +1053,39 @@ class TFViTMAEDecoder(keras.layers.Layer):
|
||||
with tf.name_scope(layer.name):
|
||||
layer.build(None)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
|
||||
"""
|
||||
This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
|
||||
allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
# [batch_size, num_patches + 1, hidden_size]
|
||||
_, num_positions, dim = shape_list(self.decoder_pos_embed)
|
||||
|
||||
# -1 removes the class dimension since we later append it without interpolation
|
||||
seq_len = shape_list(embeddings)[1] - 1
|
||||
num_positions = num_positions - 1
|
||||
|
||||
# Separation of class token and patch tokens
|
||||
class_pos_embed = self.decoder_pos_embed[:, :1, :]
|
||||
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
|
||||
|
||||
# interpolate the position embeddings
|
||||
patch_pos_embed = tf.image.resize(
|
||||
images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
|
||||
size=(1, seq_len),
|
||||
method="bicubic",
|
||||
)
|
||||
|
||||
# [1, seq_len, hidden_size]
|
||||
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
|
||||
# Adding the class token back
|
||||
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -1011,10 +1093,10 @@ class TFViTMAEDecoder(keras.layers.Layer):
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
interpolate_pos_encoding=False,
|
||||
):
|
||||
# embed tokens
|
||||
x = self.decoder_embed(hidden_states)
|
||||
|
||||
# append mask tokens to sequence
|
||||
mask_tokens = tf.tile(
|
||||
self.mask_token,
|
||||
@@ -1023,10 +1105,12 @@ class TFViTMAEDecoder(keras.layers.Layer):
|
||||
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
|
||||
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
|
||||
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
||||
else:
|
||||
decoder_pos_embed = self.decoder_pos_embed
|
||||
# add pos embed
|
||||
hidden_states = x + self.decoder_pos_embed
|
||||
|
||||
hidden_states = x + decoder_pos_embed
|
||||
# apply Transformer layers (blocks)
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
@@ -1083,11 +1167,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
raise NotImplementedError
|
||||
|
||||
def patchify(self, pixel_values):
|
||||
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
|
||||
"""
|
||||
Args:
|
||||
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
|
||||
Pixel values.
|
||||
interpolate_pos_encoding (`bool`, default `False`):
|
||||
interpolation flag passed during the forward pass.
|
||||
|
||||
Returns:
|
||||
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||
@@ -1099,11 +1185,12 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
||||
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
||||
|
||||
# sanity checks
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(pixel_values)[1],
|
||||
shape_list(pixel_values)[2],
|
||||
message="Make sure the pixel values have a squared size",
|
||||
)
|
||||
if not interpolate_pos_encoding:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(pixel_values)[1],
|
||||
shape_list(pixel_values)[2],
|
||||
message="Make sure the pixel values have a squared size",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(pixel_values)[1] % patch_size,
|
||||
0,
|
||||
@@ -1119,51 +1206,61 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
||||
|
||||
# patchify
|
||||
batch_size = shape_list(pixel_values)[0]
|
||||
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
|
||||
num_patches_h = shape_list(pixel_values)[1] // patch_size
|
||||
num_patches_w = shape_list(pixel_values)[2] // patch_size
|
||||
patchified_pixel_values = tf.reshape(
|
||||
pixel_values,
|
||||
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
|
||||
(batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
|
||||
)
|
||||
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
|
||||
patchified_pixel_values = tf.reshape(
|
||||
patchified_pixel_values,
|
||||
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
|
||||
(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
|
||||
)
|
||||
return patchified_pixel_values
|
||||
|
||||
def unpatchify(self, patchified_pixel_values):
|
||||
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
|
||||
"""
|
||||
Args:
|
||||
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||
Patchified pixel values.
|
||||
original_image_size (`Tuple[int, int]`, *optional*):
|
||||
Original image size.
|
||||
|
||||
Returns:
|
||||
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
|
||||
Pixel values.
|
||||
"""
|
||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
|
||||
original_image_size = (
|
||||
original_image_size
|
||||
if original_image_size is not None
|
||||
else (self.config.image_size, self.config.image_size)
|
||||
)
|
||||
original_height, original_width = original_image_size
|
||||
num_patches_h = original_height // patch_size
|
||||
num_patches_w = original_width // patch_size
|
||||
# sanity check
|
||||
tf.debugging.assert_equal(
|
||||
num_patches_one_direction * num_patches_one_direction,
|
||||
num_patches_h * num_patches_w,
|
||||
shape_list(patchified_pixel_values)[1],
|
||||
message="Make sure that the number of patches can be squared",
|
||||
message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
|
||||
)
|
||||
|
||||
# unpatchify
|
||||
batch_size = shape_list(patchified_pixel_values)[0]
|
||||
patchified_pixel_values = tf.reshape(
|
||||
patchified_pixel_values,
|
||||
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
|
||||
(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
|
||||
)
|
||||
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
|
||||
pixel_values = tf.reshape(
|
||||
patchified_pixel_values,
|
||||
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
|
||||
(batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
|
||||
)
|
||||
return pixel_values
|
||||
|
||||
def forward_loss(self, pixel_values, pred, mask):
|
||||
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
|
||||
"""
|
||||
Args:
|
||||
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
|
||||
@@ -1172,11 +1269,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
||||
Predicted pixel values.
|
||||
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
||||
Tensor indicating which patches are masked (1) and which are not (0).
|
||||
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||
interpolation flag passed during the forward pass.
|
||||
|
||||
Returns:
|
||||
`tf.Tensor`: Pixel reconstruction loss.
|
||||
"""
|
||||
target = self.patchify(pixel_values)
|
||||
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
if self.config.norm_pix_loss:
|
||||
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
|
||||
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
|
||||
@@ -1201,6 +1300,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -1234,16 +1334,18 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
latent = outputs.last_hidden_state
|
||||
ids_restore = outputs.ids_restore
|
||||
mask = outputs.mask
|
||||
|
||||
decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
|
||||
# [batch_size, num_patches, patch_size**2*3]
|
||||
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
logits = decoder_outputs.logits
|
||||
|
||||
loss = self.forward_loss(pixel_values, logits, mask)
|
||||
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, mask, ids_restore) + outputs[2:]
|
||||
|
||||
@@ -223,6 +223,41 @@ class ViTMAEEmbeddings(nn.Module):
|
||||
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
||||
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, 0, :]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:, :]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def random_masking(self, sequence, noise=None):
|
||||
"""
|
||||
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
||||
@@ -255,18 +290,22 @@ class ViTMAEEmbeddings(nn.Module):
|
||||
|
||||
return sequence_unmasked, mask, ids_restore
|
||||
|
||||
def forward(self, pixel_values, noise=None):
|
||||
def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
if interpolate_pos_encoding:
|
||||
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
position_embeddings = self.position_embeddings
|
||||
|
||||
# add position embeddings w/o cls token
|
||||
embeddings = embeddings + self.position_embeddings[:, 1:, :]
|
||||
embeddings = embeddings + position_embeddings[:, 1:, :]
|
||||
|
||||
# masking: length -> length * config.mask_ratio
|
||||
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
||||
|
||||
# append cls token
|
||||
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
|
||||
cls_token = self.cls_token + position_embeddings[:, :1, :]
|
||||
cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
@@ -294,13 +333,14 @@ class ViTMAEPatchEmbeddings(nn.Module):
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
@@ -657,6 +697,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||
Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher
|
||||
resolution images.
|
||||
"""
|
||||
|
||||
|
||||
@@ -698,6 +741,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, ViTMAEModelOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -735,7 +779,9 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
|
||||
embedding_output, mask, ids_restore = self.embeddings(
|
||||
pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@@ -785,6 +831,47 @@ class ViTMAEDecoder(nn.Module):
|
||||
self.config = config
|
||||
self.initialize_weights(num_patches)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
|
||||
allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
# -1 removes the class dimension since we later append it without interpolation
|
||||
embeddings_positions = embeddings.shape[1] - 1
|
||||
num_positions = self.decoder_pos_embed.shape[1] - 1
|
||||
|
||||
# Separation of class token and patch tokens
|
||||
class_pos_embed = self.decoder_pos_embed[:, 0, :]
|
||||
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
|
||||
|
||||
# To retain the final 3d tensor with the required dimensions
|
||||
dim = self.decoder_pos_embed.shape[-1]
|
||||
|
||||
# Increasing a dimension to enable bicubic interpolation
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
|
||||
|
||||
# permute to bring the dimension to be interpolated, to the last
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
# Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
|
||||
# 1 keeps the other dimension constant
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(1, embeddings_positions / num_positions),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
# Converting back to the original shape
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
# Adding the class token back
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def initialize_weights(self, num_patches):
|
||||
# initialize (and freeze) position embeddings by sin-cos embedding
|
||||
decoder_pos_embed = get_2d_sincos_pos_embed(
|
||||
@@ -802,6 +889,7 @@ class ViTMAEDecoder(nn.Module):
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
):
|
||||
# embed tokens
|
||||
x = self.decoder_embed(hidden_states)
|
||||
@@ -812,9 +900,12 @@ class ViTMAEDecoder(nn.Module):
|
||||
# unshuffle
|
||||
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
|
||||
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
||||
|
||||
# add pos embed
|
||||
hidden_states = x + self.decoder_pos_embed
|
||||
if interpolate_pos_encoding:
|
||||
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
||||
else:
|
||||
decoder_pos_embed = self.decoder_pos_embed
|
||||
hidden_states = x + decoder_pos_embed
|
||||
|
||||
# apply Transformer layers (blocks)
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@@ -893,11 +984,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def patchify(self, pixel_values):
|
||||
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
|
||||
"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values.
|
||||
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||
interpolation flag passed during the forward pass.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||
@@ -905,7 +998,9 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
"""
|
||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||
# sanity checks
|
||||
if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
|
||||
if not interpolate_pos_encoding and (
|
||||
pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
|
||||
):
|
||||
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
|
||||
if pixel_values.shape[1] != num_channels:
|
||||
raise ValueError(
|
||||
@@ -914,38 +1009,50 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
|
||||
# patchify
|
||||
batch_size = pixel_values.shape[0]
|
||||
num_patches_one_direction = pixel_values.shape[2] // patch_size
|
||||
num_patches_h = pixel_values.shape[2] // patch_size
|
||||
num_patches_w = pixel_values.shape[3] // patch_size
|
||||
patchified_pixel_values = pixel_values.reshape(
|
||||
batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
|
||||
batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
|
||||
)
|
||||
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
|
||||
patchified_pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
|
||||
batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
|
||||
)
|
||||
return patchified_pixel_values
|
||||
|
||||
def unpatchify(self, patchified_pixel_values):
|
||||
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
|
||||
"""
|
||||
Args:
|
||||
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||
Patchified pixel values.
|
||||
original_image_size (`Tuple[int, int]`, *optional*):
|
||||
Original image size.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
||||
Pixel values.
|
||||
"""
|
||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||
num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
|
||||
original_image_size = (
|
||||
original_image_size
|
||||
if original_image_size is not None
|
||||
else (self.config.image_size, self.config.image_size)
|
||||
)
|
||||
original_height, original_width = original_image_size
|
||||
num_patches_h = original_height // patch_size
|
||||
num_patches_w = original_width // patch_size
|
||||
# sanity check
|
||||
if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
|
||||
raise ValueError("Make sure that the number of patches can be squared")
|
||||
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
|
||||
raise ValueError(
|
||||
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
|
||||
)
|
||||
|
||||
# unpatchify
|
||||
batch_size = patchified_pixel_values.shape[0]
|
||||
patchified_pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size,
|
||||
num_patches_one_direction,
|
||||
num_patches_one_direction,
|
||||
num_patches_h,
|
||||
num_patches_w,
|
||||
patch_size,
|
||||
patch_size,
|
||||
num_channels,
|
||||
@@ -954,12 +1061,12 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size,
|
||||
num_channels,
|
||||
num_patches_one_direction * patch_size,
|
||||
num_patches_one_direction * patch_size,
|
||||
num_patches_h * patch_size,
|
||||
num_patches_w * patch_size,
|
||||
)
|
||||
return pixel_values
|
||||
|
||||
def forward_loss(self, pixel_values, pred, mask):
|
||||
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
|
||||
"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
@@ -968,11 +1075,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
Predicted pixel values.
|
||||
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||
Tensor indicating which patches are masked (1) and which are not (0).
|
||||
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||
interpolation flag passed during the forward pass.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Pixel reconstruction loss.
|
||||
"""
|
||||
target = self.patchify(pixel_values)
|
||||
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
if self.config.norm_pix_loss:
|
||||
mean = target.mean(dim=-1, keepdim=True)
|
||||
var = target.var(dim=-1, keepdim=True)
|
||||
@@ -980,7 +1089,6 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
|
||||
loss = (pred - target) ** 2
|
||||
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
||||
|
||||
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
||||
return loss
|
||||
|
||||
@@ -994,6 +1102,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -1026,16 +1135,17 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
latent = outputs.last_hidden_state
|
||||
ids_restore = outputs.ids_restore
|
||||
mask = outputs.mask
|
||||
|
||||
decoder_outputs = self.decoder(latent, ids_restore)
|
||||
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
|
||||
|
||||
loss = self.forward_loss(pixel_values, logits, mask)
|
||||
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, mask, ids_restore) + outputs[2:]
|
||||
|
||||
@@ -426,7 +426,7 @@ def prepare_img():
|
||||
class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
|
||||
|
||||
@slow
|
||||
def test_inference_for_pretraining(self):
|
||||
@@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||
# to visualize self-attention on higher resolution images.
|
||||
|
||||
# make random mask reproducible across the PT and TF model
|
||||
np.random.seed(2)
|
||||
|
||||
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, do_resize=False, return_tensors="tf")
|
||||
|
||||
# prepare a noise vector that will be also used for testing the TF model
|
||||
# (this way we can ensure that the PT and TF models operate on the same inputs)
|
||||
vit_mae_config = ViTMAEConfig()
|
||||
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
|
||||
noise = np.random.uniform(size=(1, num_patches))
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = tf.convert_to_tensor([1, 1200, 768])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
@@ -296,7 +296,7 @@ def prepare_img():
|
||||
class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
|
||||
|
||||
@slow
|
||||
def test_inference_for_pretraining(self):
|
||||
@@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||
# to visualize self-attention on higher resolution images.
|
||||
|
||||
# make random mask reproducible across the PT and TF model
|
||||
np.random.seed(2)
|
||||
|
||||
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device)
|
||||
|
||||
# prepare a noise vector that will be also used for testing the TF model
|
||||
# (this way we can ensure that the PT and TF models operate on the same inputs)
|
||||
vit_mae_config = ViTMAEConfig()
|
||||
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
|
||||
noise = np.random.uniform(size=(1, num_patches))
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
**inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
|
||||
)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1200, 768))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user