From e5103a76cc57f4944456d2bca4d59eade3a514f6 Mon Sep 17 00:00:00 2001 From: BHUVAN M <121122109+bhuvanmdev@users.noreply.github.com> Date: Fri, 24 May 2024 20:50:09 +0530 Subject: [PATCH] added interpolation for vitmae model in pytorch as well as tf. (#30732) * added interpolation for vitmae model in pytorch as well as tf. * Update modeling_vit_mae.py irreugalr import fixed * small changes and proper formatting * changes suggested in review. * modified decoder interpolate_func * arguments and docstring fix * Apply suggestions from code review doc fixes Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/vit_mae/modeling_tf_vit_mae.py | 164 ++++++++++++++---- .../models/vit_mae/modeling_vit_mae.py | 164 +++++++++++++++--- .../vit_mae/test_modeling_tf_vit_mae.py | 31 +++- tests/models/vit_mae/test_modeling_vit_mae.py | 34 +++- 4 files changed, 333 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index cc7a29b192..5760dbf1ef 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -240,6 +240,38 @@ class TFViTMAEEmbeddings(keras.layers.Layer): with tf.name_scope(self.patch_embeddings.name): self.patch_embeddings.build(None) + def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + batch_size, seq_len, dim = shape_list(embeddings) + num_patches = seq_len - 1 + + _, num_positions, _ = shape_list(self.position_embeddings) + num_positions -= 1 + + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + patch_pos_embed = tf.image.resize( + images=tf.reshape( + patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ), + size=(h0, w0), + method="bicubic", + ) + + patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) + return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) + def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None): """ Perform per-sample random masking by per-sample shuffling. 
Per-sample shuffling is done by argsort random @@ -281,17 +313,23 @@ class TFViTMAEEmbeddings(keras.layers.Layer): return sequence_unmasked, mask, ids_restore - def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor: - embeddings = self.patch_embeddings(pixel_values) - + def call( + self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False + ) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + if interpolate_pos_encoding: + position_embeddings = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embeddings = self.position_embeddings # add position embeddings w/o cls token - embeddings = embeddings + self.position_embeddings[:, 1:, :] + embeddings = embeddings + position_embeddings[:, 1:, :] # masking: length -> length * config.mask_ratio embeddings, mask, ids_restore = self.random_masking(embeddings, noise) # append cls token - cls_token = self.cls_token + self.position_embeddings[:, :1, :] + cls_token = self.cls_token + position_embeddings[:, :1, :] cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1)) embeddings = tf.concat([cls_tokens, embeddings], axis=1) @@ -329,7 +367,9 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer): name="projection", ) - def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + def call( + self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False + ) -> tf.Tensor: batch_size, num_channels, height, width = shape_list(pixel_values) if tf.executing_eagerly(): if num_channels != self.num_channels: @@ -337,7 +377,7 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer): "Make sure that the channel dimension of the pixel values match with the one set in the" " configuration." 
) - if height != self.image_size[0] or width != self.image_size[1]: + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): raise ValueError( f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size[0]}*{self.image_size[1]})." @@ -741,9 +781,13 @@ class TFViTMAEMainLayer(keras.layers.Layer): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, + interpolate_pos_encoding: bool = False, ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]: embedding_output, mask, ids_restore = self.embeddings( - pixel_values=pixel_values, training=training, noise=noise + pixel_values=pixel_values, + training=training, + noise=noise, + interpolate_pos_encoding=interpolate_pos_encoding, ) # Prepare head mask if needed @@ -874,6 +918,9 @@ VIT_MAE_INPUTS_DOCSTRING = r""" training (`bool`, *optional*, defaults to `False``): Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation). + + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the position encodings at the encoder and decoder. 
""" @@ -902,6 +949,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, + interpolate_pos_encoding: bool = False, ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]: r""" Returns: @@ -931,6 +979,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, + interpolate_pos_encoding=interpolate_pos_encoding, ) return outputs @@ -1004,6 +1053,39 @@ class TFViTMAEDecoder(keras.layers.Layer): with tf.name_scope(layer.name): layer.build(None) + def interpolate_pos_encoding(self, embeddings) -> tf.Tensor: + """ + This method is a modified version of the interpolation function for ViT-mae model at the decoder, that + allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + # [batch_size, num_patches + 1, hidden_size] + _, num_positions, dim = shape_list(self.decoder_pos_embed) + + # -1 removes the class dimension since we later append it without interpolation + seq_len = shape_list(embeddings)[1] - 1 + num_positions = num_positions - 1 + + # Separation of class token and patch tokens + class_pos_embed = self.decoder_pos_embed[:, :1, :] + patch_pos_embed = self.decoder_pos_embed[:, 1:, :] + + # interpolate the position embeddings + patch_pos_embed = tf.image.resize( + images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)), + size=(1, seq_len), + method="bicubic", + ) + + # [1, seq_len, hidden_size] + patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) + # Adding the class token back + return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) + def call( self, hidden_states, @@ -1011,10 +1093,10 @@ class TFViTMAEDecoder(keras.layers.Layer): 
output_attentions=False, output_hidden_states=False, return_dict=True, + interpolate_pos_encoding=False, ): # embed tokens x = self.decoder_embed(hidden_states) - # append mask tokens to sequence mask_tokens = tf.tile( self.mask_token, @@ -1023,10 +1105,12 @@ class TFViTMAEDecoder(keras.layers.Layer): x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token - + if interpolate_pos_encoding: + decoder_pos_embed = self.interpolate_pos_encoding(x) + else: + decoder_pos_embed = self.decoder_pos_embed # add pos embed - hidden_states = x + self.decoder_pos_embed - + hidden_states = x + decoder_pos_embed # apply Transformer layers (blocks) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1083,11 +1167,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): def _prune_heads(self, heads_to_prune): raise NotImplementedError - def patchify(self, pixel_values): + def patchify(self, pixel_values, interpolate_pos_encoding: bool = False): """ Args: pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`): Pixel values. + interpolate_pos_encoding (`bool`, default `False`): + interpolation flag passed during the forward pass. 
Returns: `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: @@ -1099,11 +1185,12 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) # sanity checks - tf.debugging.assert_equal( - shape_list(pixel_values)[1], - shape_list(pixel_values)[2], - message="Make sure the pixel values have a squared size", - ) + if not interpolate_pos_encoding: + tf.debugging.assert_equal( + shape_list(pixel_values)[1], + shape_list(pixel_values)[2], + message="Make sure the pixel values have a squared size", + ) tf.debugging.assert_equal( shape_list(pixel_values)[1] % patch_size, 0, @@ -1119,51 +1206,61 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): # patchify batch_size = shape_list(pixel_values)[0] - num_patches_one_direction = shape_list(pixel_values)[2] // patch_size + num_patches_h = shape_list(pixel_values)[1] // patch_size + num_patches_w = shape_list(pixel_values)[2] // patch_size patchified_pixel_values = tf.reshape( pixel_values, - (batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels), + (batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels), ) patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values) patchified_pixel_values = tf.reshape( patchified_pixel_values, - (batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels), + (batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels), ) return patchified_pixel_values - def unpatchify(self, patchified_pixel_values): + def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None): """ Args: patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: Patchified pixel values. + original_image_size (`Tuple[int, int]`, *optional*): + Original image size. 
Returns: `tf.Tensor` of shape `(batch_size, height, width, num_channels)`: Pixel values. """ patch_size, num_channels = self.config.patch_size, self.config.num_channels - num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5) + original_image_size = ( + original_image_size + if original_image_size is not None + else (self.config.image_size, self.config.image_size) + ) + original_height, original_width = original_image_size + num_patches_h = original_height // patch_size + num_patches_w = original_width // patch_size # sanity check tf.debugging.assert_equal( - num_patches_one_direction * num_patches_one_direction, + num_patches_h * num_patches_w, shape_list(patchified_pixel_values)[1], - message="Make sure that the number of patches can be squared", + message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}", ) # unpatchify batch_size = shape_list(patchified_pixel_values)[0] patchified_pixel_values = tf.reshape( patchified_pixel_values, - (batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels), + (batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels), ) patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values) pixel_values = tf.reshape( patchified_pixel_values, - (batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels), + (batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels), ) return pixel_values - def forward_loss(self, pixel_values, pred, mask): + def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False): """ Args: pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`): @@ -1172,11 +1269,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): Predicted pixel values. 
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): Tensor indicating which patches are masked (1) and which are not (0). + interpolate_pos_encoding (`bool`, *optional*, default `False`): + interpolation flag passed during the forward pass. Returns: `tf.Tensor`: Pixel reconstruction loss. """ - target = self.patchify(pixel_values) + target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) if self.config.norm_pix_loss: mean = tf.reduce_mean(target, axis=-1, keepdims=True) var = tf.math.reduce_variance(target, axis=-1, keepdims=True) @@ -1201,6 +1300,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, + interpolate_pos_encoding: bool = False, ) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]: r""" Returns: @@ -1234,16 +1334,18 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, + interpolate_pos_encoding=interpolate_pos_encoding, ) latent = outputs.last_hidden_state ids_restore = outputs.ids_restore mask = outputs.mask - decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3] + # [batch_size, num_patches, patch_size**2*3] + decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding) logits = decoder_outputs.logits - loss = self.forward_loss(pixel_values, logits, mask) + loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding) if not return_dict: output = (logits, mask, ids_restore) + outputs[2:] diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index f2afd84fc1..e85d996f47 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -223,6 +223,41 @@ class 
ViTMAEEmbeddings(nn.Module): # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range) + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, 0, :] + patch_pos_embed = self.position_embeddings[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + def random_masking(self, sequence, noise=None): """ Perform per-sample random masking by per-sample shuffling. 
Per-sample shuffling is done by argsort random @@ -255,18 +290,22 @@ class ViTMAEEmbeddings(nn.Module): return sequence_unmasked, mask, ids_restore - def forward(self, pixel_values, noise=None): + def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False): batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings(pixel_values) + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + if interpolate_pos_encoding: + position_embeddings = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embeddings = self.position_embeddings # add position embeddings w/o cls token - embeddings = embeddings + self.position_embeddings[:, 1:, :] + embeddings = embeddings + position_embeddings[:, 1:, :] # masking: length -> length * config.mask_ratio embeddings, mask, ids_restore = self.random_masking(embeddings, noise) # append cls token - cls_token = self.cls_token + self.position_embeddings[:, :1, :] + cls_token = self.cls_token + position_embeddings[:, :1, :] cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1) embeddings = torch.cat((cls_tokens, embeddings), dim=1) @@ -294,13 +333,14 @@ class ViTMAEPatchEmbeddings(nn.Module): self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values): + def forward(self, pixel_values, interpolate_pos_encoding: bool = False): batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) - if height != self.image_size[0] or width != self.image_size[1]: + + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." 
) @@ -657,6 +697,9 @@ VIT_MAE_INPUTS_DOCSTRING = r""" more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, default `False`): + Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher + resolution images. """ @@ -698,6 +741,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, ViTMAEModelOutput]: r""" Returns: @@ -735,7 +779,9 @@ class ViTMAEModel(ViTMAEPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise) + embedding_output, mask, ids_restore = self.embeddings( + pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding + ) encoder_outputs = self.encoder( embedding_output, @@ -785,6 +831,47 @@ class ViTMAEDecoder(nn.Module): self.config = config self.initialize_weights(num_patches) + def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor: + """ + This method is a modified version of the interpolation function for ViT-mae model at the decoder, that + allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + # -1 removes the class dimension since we later append it without interpolation + embeddings_positions = embeddings.shape[1] - 1 + num_positions = self.decoder_pos_embed.shape[1] - 1 + + # Separation of class token and patch tokens + class_pos_embed = self.decoder_pos_embed[:, 0, :] + patch_pos_embed = self.decoder_pos_embed[:, 1:, :] + + # To retain the final 3d tensor with the required dimensions + dim = self.decoder_pos_embed.shape[-1] + + # Increasing a dimension to enable bicubic interpolation + patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim) + + # permute to bring the dimension to be interpolated, to the last + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x). + # 1 keeps the other dimension constant + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(1, embeddings_positions / num_positions), + mode="bicubic", + align_corners=False, + ) + + # Converting back to the original shape + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + # Adding the class token back + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + def initialize_weights(self, num_patches): # initialize (and freeze) position embeddings by sin-cos embedding decoder_pos_embed = get_2d_sincos_pos_embed( @@ -802,6 +889,7 @@ class ViTMAEDecoder(nn.Module): output_attentions=False, output_hidden_states=False, return_dict=True, + interpolate_pos_encoding: bool = False, ): # embed tokens x = self.decoder_embed(hidden_states) @@ -812,9 +900,12 @@ class ViTMAEDecoder(nn.Module): # unshuffle x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)) x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token - # add pos embed - hidden_states = x + 
self.decoder_pos_embed + if interpolate_pos_encoding: + decoder_pos_embed = self.interpolate_pos_encoding(x) + else: + decoder_pos_embed = self.decoder_pos_embed + hidden_states = x + decoder_pos_embed # apply Transformer layers (blocks) all_hidden_states = () if output_hidden_states else None @@ -893,11 +984,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def patchify(self, pixel_values): + def patchify(self, pixel_values, interpolate_pos_encoding: bool = False): """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. + interpolate_pos_encoding (`bool`, *optional*, default `False`): + interpolation flag passed during the forward pass. Returns: `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: @@ -905,7 +998,9 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): """ patch_size, num_channels = self.config.patch_size, self.config.num_channels # sanity checks - if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0): + if not interpolate_pos_encoding and ( + pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0 + ): raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size") if pixel_values.shape[1] != num_channels: raise ValueError( @@ -914,38 +1009,50 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): # patchify batch_size = pixel_values.shape[0] - num_patches_one_direction = pixel_values.shape[2] // patch_size + num_patches_h = pixel_values.shape[2] // patch_size + num_patches_w = pixel_values.shape[3] // patch_size patchified_pixel_values = pixel_values.reshape( - batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size + batch_size, num_channels, num_patches_h, patch_size, num_patches_w, 
patch_size ) patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values) patchified_pixel_values = patchified_pixel_values.reshape( - batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels + batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels ) return patchified_pixel_values - def unpatchify(self, patchified_pixel_values): + def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None): """ Args: patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: Patchified pixel values. + original_image_size (`Tuple[int, int]`, *optional*): + Original image size. Returns: `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: Pixel values. """ patch_size, num_channels = self.config.patch_size, self.config.num_channels - num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5) + original_image_size = ( + original_image_size + if original_image_size is not None + else (self.config.image_size, self.config.image_size) + ) + original_height, original_width = original_image_size + num_patches_h = original_height // patch_size + num_patches_w = original_width // patch_size # sanity check - if num_patches_one_direction**2 != patchified_pixel_values.shape[1]: - raise ValueError("Make sure that the number of patches can be squared") + if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]: + raise ValueError( + f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}" + ) # unpatchify batch_size = patchified_pixel_values.shape[0] patchified_pixel_values = patchified_pixel_values.reshape( batch_size, - num_patches_one_direction, - num_patches_one_direction, + num_patches_h, + num_patches_w, patch_size, patch_size, num_channels, @@ -954,12 
+1061,12 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): pixel_values = patchified_pixel_values.reshape( batch_size, num_channels, - num_patches_one_direction * patch_size, - num_patches_one_direction * patch_size, + num_patches_h * patch_size, + num_patches_w * patch_size, ) return pixel_values - def forward_loss(self, pixel_values, pred, mask): + def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False): """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): @@ -968,11 +1075,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): Predicted pixel values. mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Tensor indicating which patches are masked (1) and which are not (0). + interpolate_pos_encoding (`bool`, *optional*, default `False`): + interpolation flag passed during the forward pass. Returns: `torch.FloatTensor`: Pixel reconstruction loss. """ - target = self.patchify(pixel_values) + target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) if self.config.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) @@ -980,7 +1089,6 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch - loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches return loss @@ -994,6 +1102,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, ViTMAEForPreTrainingOutput]: r""" Returns: @@ -1026,16 +1135,17 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) latent = 
outputs.last_hidden_state ids_restore = outputs.ids_restore mask = outputs.mask - decoder_outputs = self.decoder(latent, ids_restore) + decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding) logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) - loss = self.forward_loss(pixel_values, logits, mask) + loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding) if not return_dict: output = (logits, mask, ids_restore) + outputs[2:] diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 5c27e5ac80..5344c3187c 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -426,7 +426,7 @@ def prepare_img(): class TFViTMAEModelIntegrationTest(unittest.TestCase): @cached_property def default_image_processor(self): - return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None + return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") @slow def test_inference_for_pretraining(self): @@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase): ) tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViTMAE models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. 
+ + # make random mask reproducible across the PT and TF model + np.random.seed(2) + + model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base") + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, do_resize=False, return_tensors="tf") + + # prepare a noise vector that will be also used for testing the TF model + # (this way we can ensure that the PT and TF models operate on the same inputs) + vit_mae_config = ViTMAEConfig() + num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size) + noise = np.random.uniform(size=(1, num_patches)) + + # forward pass + outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = tf.convert_to_tensor([1, 1200, 768]) + self.assertEqual(outputs.logits.shape, expected_shape) diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index 0357a3ebda..506660f089 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -296,7 +296,7 @@ def prepare_img(): class ViTMAEModelIntegrationTest(unittest.TestCase): @cached_property def default_image_processor(self): - return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None + return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") @slow def test_inference_for_pretraining(self): @@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase): ) self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViTMAE models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. 
The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + + # make random mask reproducible across the PT and TF model + np.random.seed(2) + + model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device) + + # prepare a noise vector that will be also used for testing the TF model + # (this way we can ensure that the PT and TF models operate on the same inputs) + vit_mae_config = ViTMAEConfig() + num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size) + noise = np.random.uniform(size=(1, num_patches)) + + # forward pass + with torch.no_grad(): + outputs = model( + **inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True + ) + + # verify the logits + expected_shape = torch.Size((1, 1200, 768)) + self.assertEqual(outputs.logits.shape, expected_shape)