added interpolation for vitmae model in pytorch as well as tf. (#30732)
* added interpolation for vitmae model in pytorch as well as tf. * Update modeling_vit_mae.py irregular import fixed * small changes and proper formatting * changes suggested in review. * modified decoder interpolate_func * arguments and docstring fix * Apply suggestions from code review doc fixes Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -240,6 +240,38 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
|
|||||||
with tf.name_scope(self.patch_embeddings.name):
|
with tf.name_scope(self.patch_embeddings.name):
|
||||||
self.patch_embeddings.build(None)
|
self.patch_embeddings.build(None)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
    """
    Interpolate the pre-trained position encodings so that the model can be used on images whose resolution
    differs from the one the position table was built for.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings (`tf.Tensor`):
            Patch embeddings as produced in `call`, i.e. *before* the CLS token is prepended. Axis 1 is read as
            the number of patches and the last axis as the hidden size.
        height (`int`):
            Height of the input image in pixels.
        width (`int`):
            Width of the input image in pixels.

    Returns:
        `tf.Tensor`: Position embeddings with the CLS position first, followed by
        `(height // patch_size) * (width // patch_size)` patch positions. When the input already matches the
        pre-trained grid, `self.position_embeddings` is returned unchanged.
    """
    _, seq_len, dim = shape_list(embeddings)
    # `embeddings` holds patch tokens only (the CLS token is concatenated later in `call`),
    # so the sequence length already is the number of patches; subtracting 1 here would make
    # the fast path below unreachable for native-resolution inputs.
    num_patches = seq_len

    # The stored table has one extra leading entry for the CLS token.
    num_positions = shape_list(self.position_embeddings)[1] - 1

    if num_patches == num_positions and height == width:
        return self.position_embeddings

    class_pos_embed = self.position_embeddings[:, :1]
    patch_pos_embed = self.position_embeddings[:, 1:]

    # Target grid, in patches.
    h0 = height // self.config.patch_size
    w0 = width // self.config.patch_size

    # The pre-trained patch positions are laid out as a square grid.
    grid_size = int(math.sqrt(num_positions))
    patch_pos_embed = tf.image.resize(
        images=tf.reshape(patch_pos_embed, shape=(1, grid_size, grid_size, dim)),
        size=(h0, w0),
        method="bicubic",
    )

    # Flatten the resized grid back into a sequence and put the CLS position in front again.
    patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
    return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
|
||||||
|
|
||||||
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
|
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
|
||||||
"""
|
"""
|
||||||
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
||||||
@@ -281,17 +313,23 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
|
|||||||
|
|
||||||
return sequence_unmasked, mask, ids_restore
|
return sequence_unmasked, mask, ids_restore
|
||||||
|
|
||||||
def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
|
def call(
|
||||||
embeddings = self.patch_embeddings(pixel_values)
|
self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False
|
||||||
|
) -> tf.Tensor:
|
||||||
|
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||||
|
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
|
position_embeddings = self.position_embeddings
|
||||||
# add position embeddings w/o cls token
|
# add position embeddings w/o cls token
|
||||||
embeddings = embeddings + self.position_embeddings[:, 1:, :]
|
embeddings = embeddings + position_embeddings[:, 1:, :]
|
||||||
|
|
||||||
# masking: length -> length * config.mask_ratio
|
# masking: length -> length * config.mask_ratio
|
||||||
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
||||||
|
|
||||||
# append cls token
|
# append cls token
|
||||||
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
|
cls_token = self.cls_token + position_embeddings[:, :1, :]
|
||||||
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
|
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
|
||||||
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
|
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
|
||||||
|
|
||||||
@@ -329,7 +367,9 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
|
|||||||
name="projection",
|
name="projection",
|
||||||
)
|
)
|
||||||
|
|
||||||
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
|
def call(
|
||||||
|
self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
|
||||||
|
) -> tf.Tensor:
|
||||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||||
if tf.executing_eagerly():
|
if tf.executing_eagerly():
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
@@ -337,7 +377,7 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
|
|||||||
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
||||||
" configuration."
|
" configuration."
|
||||||
)
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model"
|
f"Input image size ({height}*{width}) doesn't match model"
|
||||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
@@ -741,9 +781,13 @@ class TFViTMAEMainLayer(keras.layers.Layer):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
||||||
embedding_output, mask, ids_restore = self.embeddings(
|
embedding_output, mask, ids_restore = self.embeddings(
|
||||||
pixel_values=pixel_values, training=training, noise=noise
|
pixel_values=pixel_values,
|
||||||
|
training=training,
|
||||||
|
noise=noise,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
@@ -874,6 +918,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
|
|||||||
training (`bool`, *optional*, defaults to `False``):
|
training (`bool`, *optional*, defaults to `False``):
|
||||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||||
behaviors between training and evaluation).
|
behaviors between training and evaluation).
|
||||||
|
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate the position encodings at the encoder and decoder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -902,6 +949,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -931,6 +979,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
@@ -1004,6 +1053,39 @@ class TFViTMAEDecoder(keras.layers.Layer):
|
|||||||
with tf.name_scope(layer.name):
|
with tf.name_scope(layer.name):
|
||||||
layer.build(None)
|
layer.build(None)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
    """
    Decoder-side variant of the ViT-MAE position-encoding interpolation: resizes the pre-trained decoder
    position encodings along the sequence axis so that they match the length of `embeddings`, which makes it
    possible to run the model on higher resolution images.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings (`tf.Tensor`):
            Decoder input sequence including the leading class token; only the length of axis 1 is used.

    Returns:
        `tf.Tensor`: The class-token position followed by the patch positions resized to
        `embeddings.shape[1] - 1` entries.
    """
    pos_embed = self.decoder_pos_embed

    # decoder_pos_embed is unpacked as a rank-3 tensor; only its last axis (hidden size) is needed
    _, _, hidden_dim = shape_list(pos_embed)

    # the class token is excluded from the target length: it is re-attached below without interpolation
    target_len = shape_list(embeddings)[1] - 1

    # split the table into the class-token position and the patch positions
    cls_part, patch_part = pos_embed[:, :1, :], pos_embed[:, 1:, :]

    # treat the patch positions as a 1 x N "image" so that bicubic resizing acts along the sequence axis
    as_image = tf.reshape(patch_part, shape=(1, 1, -1, hidden_dim))
    resized = tf.image.resize(images=as_image, size=(1, target_len), method="bicubic")

    # back to [1, target_len, hidden_dim], then put the class-token position in front again
    resized = tf.reshape(tensor=resized, shape=(1, -1, hidden_dim))
    return tf.concat(values=(cls_part, resized), axis=1)
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -1011,10 +1093,10 @@ class TFViTMAEDecoder(keras.layers.Layer):
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
interpolate_pos_encoding=False,
|
||||||
):
|
):
|
||||||
# embed tokens
|
# embed tokens
|
||||||
x = self.decoder_embed(hidden_states)
|
x = self.decoder_embed(hidden_states)
|
||||||
|
|
||||||
# append mask tokens to sequence
|
# append mask tokens to sequence
|
||||||
mask_tokens = tf.tile(
|
mask_tokens = tf.tile(
|
||||||
self.mask_token,
|
self.mask_token,
|
||||||
@@ -1023,10 +1105,12 @@ class TFViTMAEDecoder(keras.layers.Layer):
|
|||||||
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
|
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
|
||||||
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
|
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
|
||||||
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
|
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
||||||
|
else:
|
||||||
|
decoder_pos_embed = self.decoder_pos_embed
|
||||||
# add pos embed
|
# add pos embed
|
||||||
hidden_states = x + self.decoder_pos_embed
|
hidden_states = x + decoder_pos_embed
|
||||||
|
|
||||||
# apply Transformer layers (blocks)
|
# apply Transformer layers (blocks)
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
@@ -1083,11 +1167,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def patchify(self, pixel_values):
|
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
|
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
|
||||||
Pixel values.
|
Pixel values.
|
||||||
|
interpolate_pos_encoding (`bool`, default `False`):
|
||||||
|
interpolation flag passed during the forward pass.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||||
@@ -1099,6 +1185,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
||||||
|
|
||||||
# sanity checks
|
# sanity checks
|
||||||
|
if not interpolate_pos_encoding:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(pixel_values)[1],
|
shape_list(pixel_values)[1],
|
||||||
shape_list(pixel_values)[2],
|
shape_list(pixel_values)[2],
|
||||||
@@ -1119,51 +1206,61 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
|
|
||||||
# patchify
|
# patchify
|
||||||
batch_size = shape_list(pixel_values)[0]
|
batch_size = shape_list(pixel_values)[0]
|
||||||
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
|
num_patches_h = shape_list(pixel_values)[1] // patch_size
|
||||||
|
num_patches_w = shape_list(pixel_values)[2] // patch_size
|
||||||
patchified_pixel_values = tf.reshape(
|
patchified_pixel_values = tf.reshape(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
|
(batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
|
||||||
)
|
)
|
||||||
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
|
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
|
||||||
patchified_pixel_values = tf.reshape(
|
patchified_pixel_values = tf.reshape(
|
||||||
patchified_pixel_values,
|
patchified_pixel_values,
|
||||||
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
|
(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
|
||||||
)
|
)
|
||||||
return patchified_pixel_values
|
return patchified_pixel_values
|
||||||
|
|
||||||
def unpatchify(self, patchified_pixel_values):
|
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||||
Patchified pixel values.
|
Patchified pixel values.
|
||||||
|
original_image_size (`Tuple[int, int]`, *optional*):
|
||||||
|
Original image size.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
|
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
|
||||||
Pixel values.
|
Pixel values.
|
||||||
"""
|
"""
|
||||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||||
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
|
original_image_size = (
|
||||||
|
original_image_size
|
||||||
|
if original_image_size is not None
|
||||||
|
else (self.config.image_size, self.config.image_size)
|
||||||
|
)
|
||||||
|
original_height, original_width = original_image_size
|
||||||
|
num_patches_h = original_height // patch_size
|
||||||
|
num_patches_w = original_width // patch_size
|
||||||
# sanity check
|
# sanity check
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
num_patches_one_direction * num_patches_one_direction,
|
num_patches_h * num_patches_w,
|
||||||
shape_list(patchified_pixel_values)[1],
|
shape_list(patchified_pixel_values)[1],
|
||||||
message="Make sure that the number of patches can be squared",
|
message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# unpatchify
|
# unpatchify
|
||||||
batch_size = shape_list(patchified_pixel_values)[0]
|
batch_size = shape_list(patchified_pixel_values)[0]
|
||||||
patchified_pixel_values = tf.reshape(
|
patchified_pixel_values = tf.reshape(
|
||||||
patchified_pixel_values,
|
patchified_pixel_values,
|
||||||
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
|
(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
|
||||||
)
|
)
|
||||||
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
|
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
|
||||||
pixel_values = tf.reshape(
|
pixel_values = tf.reshape(
|
||||||
patchified_pixel_values,
|
patchified_pixel_values,
|
||||||
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
|
(batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
|
||||||
)
|
)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward_loss(self, pixel_values, pred, mask):
|
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
|
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
|
||||||
@@ -1172,11 +1269,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
Predicted pixel values.
|
Predicted pixel values.
|
||||||
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
||||||
Tensor indicating which patches are masked (1) and which are not (0).
|
Tensor indicating which patches are masked (1) and which are not (0).
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||||
|
interpolation flag passed during the forward pass.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`tf.Tensor`: Pixel reconstruction loss.
|
`tf.Tensor`: Pixel reconstruction loss.
|
||||||
"""
|
"""
|
||||||
target = self.patchify(pixel_values)
|
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
if self.config.norm_pix_loss:
|
if self.config.norm_pix_loss:
|
||||||
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
|
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
|
||||||
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
|
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
|
||||||
@@ -1201,6 +1300,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
|
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1234,16 +1334,18 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
latent = outputs.last_hidden_state
|
latent = outputs.last_hidden_state
|
||||||
ids_restore = outputs.ids_restore
|
ids_restore = outputs.ids_restore
|
||||||
mask = outputs.mask
|
mask = outputs.mask
|
||||||
|
|
||||||
decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
|
# [batch_size, num_patches, patch_size**2*3]
|
||||||
|
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
logits = decoder_outputs.logits
|
logits = decoder_outputs.logits
|
||||||
|
|
||||||
loss = self.forward_loss(pixel_values, logits, mask)
|
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits, mask, ids_restore) + outputs[2:]
|
output = (logits, mask, ids_restore) + outputs[2:]
|
||||||
|
|||||||
@@ -223,6 +223,41 @@ class ViTMAEEmbeddings(nn.Module):
|
|||||||
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
||||||
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
|
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Interpolate the pre-trained position encodings so that the model can be used on images whose resolution
    differs from the one the position table was built for.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings (`torch.Tensor`):
            Patch embeddings as produced in `forward`, i.e. *before* the CLS token is prepended. Axis 1 is read
            as the number of patches and the last axis as the hidden size.
        height (`int`):
            Height of the input image in pixels.
        width (`int`):
            Width of the input image in pixels.

    Returns:
        `torch.Tensor`: Position embeddings with the CLS position first, followed by
        `(height // patch_size) * (width // patch_size)` patch positions. When the input already matches the
        pre-trained grid, `self.position_embeddings` is returned unchanged.

    Raises:
        ValueError: If the interpolated grid does not have the expected height or width.
    """
    # `embeddings` holds patch tokens only (the CLS token is concatenated later in `forward`),
    # so its sequence length already is the number of patches. Subtracting 1 here would make the
    # fast path below unreachable and silently resample the table even at native resolution.
    num_patches = embeddings.shape[1]
    # The stored table has one extra leading entry for the CLS token.
    num_positions = self.position_embeddings.shape[1] - 1

    if num_patches == num_positions and height == width:
        return self.position_embeddings

    class_pos_embed = self.position_embeddings[:, :1, :]
    patch_pos_embed = self.position_embeddings[:, 1:, :]
    dim = embeddings.shape[-1]

    # The pre-trained patch positions are laid out as a square grid.
    grid_size = int(math.sqrt(num_positions))

    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    h0 = height // self.config.patch_size + 0.1
    w0 = width // self.config.patch_size + 0.1

    # [1, grid, grid, dim] -> [1, dim, grid, grid]: interpolate expects channels first
    patch_pos_embed = patch_pos_embed.reshape(1, grid_size, grid_size, dim).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
        mode="bicubic",
        align_corners=False,
    )
    if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
        raise ValueError("Width or height does not match with the interpolated position embeddings")

    # Back to a flat [1, h0 * w0, dim] sequence, then put the CLS position in front again.
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||||
|
|
||||||
def random_masking(self, sequence, noise=None):
|
def random_masking(self, sequence, noise=None):
|
||||||
"""
|
"""
|
||||||
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
||||||
@@ -255,18 +290,22 @@ class ViTMAEEmbeddings(nn.Module):
|
|||||||
|
|
||||||
return sequence_unmasked, mask, ids_restore
|
return sequence_unmasked, mask, ids_restore
|
||||||
|
|
||||||
def forward(self, pixel_values, noise=None):
|
def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
embeddings = self.patch_embeddings(pixel_values)
|
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
if interpolate_pos_encoding:
|
||||||
|
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
|
||||||
|
else:
|
||||||
|
position_embeddings = self.position_embeddings
|
||||||
|
|
||||||
# add position embeddings w/o cls token
|
# add position embeddings w/o cls token
|
||||||
embeddings = embeddings + self.position_embeddings[:, 1:, :]
|
embeddings = embeddings + position_embeddings[:, 1:, :]
|
||||||
|
|
||||||
# masking: length -> length * config.mask_ratio
|
# masking: length -> length * config.mask_ratio
|
||||||
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
||||||
|
|
||||||
# append cls token
|
# append cls token
|
||||||
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
|
cls_token = self.cls_token + position_embeddings[:, :1, :]
|
||||||
cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
|
cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
|
||||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||||
|
|
||||||
@@ -294,13 +333,14 @@ class ViTMAEPatchEmbeddings(nn.Module):
|
|||||||
|
|
||||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
|
||||||
batch_size, num_channels, height, width = pixel_values.shape
|
batch_size, num_channels, height, width = pixel_values.shape
|
||||||
if num_channels != self.num_channels:
|
if num_channels != self.num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||||
)
|
)
|
||||||
if height != self.image_size[0] or width != self.image_size[1]:
|
|
||||||
|
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||||
)
|
)
|
||||||
@@ -657,6 +697,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||||
|
Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher
|
||||||
|
resolution images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -698,6 +741,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple, ViTMAEModelOutput]:
|
) -> Union[Tuple, ViTMAEModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -735,7 +779,9 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
|
|||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
|
embedding_output, mask, ids_restore = self.embeddings(
|
||||||
|
pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -785,6 +831,47 @@ class ViTMAEDecoder(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.initialize_weights(num_patches)
|
self.initialize_weights(num_patches)
|
||||||
|
|
||||||
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
    """
    Decoder-side variant of the ViT-MAE position-encoding interpolation: resizes the pre-trained decoder
    position encodings along the sequence axis so that they match the length of `embeddings`, which makes it
    possible to run the model on higher resolution images.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings (`torch.Tensor`):
            Decoder input sequence including the leading class token; only the length of axis 1 is used.

    Returns:
        `torch.Tensor`: The class-token position followed by the patch positions resized to
        `embeddings.shape[1] - 1` entries. When the lengths already match, `self.decoder_pos_embed` is
        returned unchanged.
    """
    # -1 removes the class dimension since we later append it without interpolation
    embeddings_positions = embeddings.shape[1] - 1
    num_positions = self.decoder_pos_embed.shape[1] - 1

    # Nothing to resize when the sequence already has the pre-trained length.
    if embeddings_positions == num_positions:
        return self.decoder_pos_embed

    # Separation of class token and patch tokens
    class_pos_embed = self.decoder_pos_embed[:, :1, :]
    patch_pos_embed = self.decoder_pos_embed[:, 1:, :]

    # To retain the final 3d tensor with the required dimensions
    dim = self.decoder_pos_embed.shape[-1]

    # Add a dummy height of 1 and move the hidden size to the channel axis:
    # [1, N, dim] -> [1, 1, N, dim] -> [1, dim, 1, N], so bicubic interpolation acts on the sequence axis.
    patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)

    # Request the target length explicitly. Deriving it from a float scale factor
    # (floor(N * (target / N))) can round down by one and yield a sequence that is too short.
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(1, embeddings_positions),
        mode="bicubic",
        align_corners=False,
    )

    # Converting back to the original layout: [1, target, dim]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
    # Adding the class token back
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||||
|
|
||||||
def initialize_weights(self, num_patches):
|
def initialize_weights(self, num_patches):
|
||||||
# initialize (and freeze) position embeddings by sin-cos embedding
|
# initialize (and freeze) position embeddings by sin-cos embedding
|
||||||
decoder_pos_embed = get_2d_sincos_pos_embed(
|
decoder_pos_embed = get_2d_sincos_pos_embed(
|
||||||
@@ -802,6 +889,7 @@ class ViTMAEDecoder(nn.Module):
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
):
|
):
|
||||||
# embed tokens
|
# embed tokens
|
||||||
x = self.decoder_embed(hidden_states)
|
x = self.decoder_embed(hidden_states)
|
||||||
@@ -812,9 +900,12 @@ class ViTMAEDecoder(nn.Module):
|
|||||||
# unshuffle
|
# unshuffle
|
||||||
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
|
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
|
||||||
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
||||||
|
|
||||||
# add pos embed
|
# add pos embed
|
||||||
hidden_states = x + self.decoder_pos_embed
|
if interpolate_pos_encoding:
|
||||||
|
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
||||||
|
else:
|
||||||
|
decoder_pos_embed = self.decoder_pos_embed
|
||||||
|
hidden_states = x + decoder_pos_embed
|
||||||
|
|
||||||
# apply Transformer layers (blocks)
|
# apply Transformer layers (blocks)
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
@@ -893,11 +984,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def patchify(self, pixel_values):
|
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
Pixel values.
|
Pixel values.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||||
|
interpolation flag passed during the forward pass.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||||
@@ -905,7 +998,9 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||||
# sanity checks
|
# sanity checks
|
||||||
if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
|
if not interpolate_pos_encoding and (
|
||||||
|
pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
|
||||||
|
):
|
||||||
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
|
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
|
||||||
if pixel_values.shape[1] != num_channels:
|
if pixel_values.shape[1] != num_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -914,38 +1009,50 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
|
|
||||||
# patchify
|
# patchify
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size = pixel_values.shape[0]
|
||||||
num_patches_one_direction = pixel_values.shape[2] // patch_size
|
num_patches_h = pixel_values.shape[2] // patch_size
|
||||||
|
num_patches_w = pixel_values.shape[3] // patch_size
|
||||||
patchified_pixel_values = pixel_values.reshape(
|
patchified_pixel_values = pixel_values.reshape(
|
||||||
batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
|
batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
|
||||||
)
|
)
|
||||||
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
|
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
|
||||||
patchified_pixel_values = patchified_pixel_values.reshape(
|
patchified_pixel_values = patchified_pixel_values.reshape(
|
||||||
batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
|
batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
|
||||||
)
|
)
|
||||||
return patchified_pixel_values
|
return patchified_pixel_values
|
||||||
|
|
||||||
def unpatchify(self, patchified_pixel_values):
|
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
||||||
Patchified pixel values.
|
Patchified pixel values.
|
||||||
|
original_image_size (`Tuple[int, int]`, *optional*):
|
||||||
|
Original image size.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
||||||
Pixel values.
|
Pixel values.
|
||||||
"""
|
"""
|
||||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||||
num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
|
original_image_size = (
|
||||||
|
original_image_size
|
||||||
|
if original_image_size is not None
|
||||||
|
else (self.config.image_size, self.config.image_size)
|
||||||
|
)
|
||||||
|
original_height, original_width = original_image_size
|
||||||
|
num_patches_h = original_height // patch_size
|
||||||
|
num_patches_w = original_width // patch_size
|
||||||
# sanity check
|
# sanity check
|
||||||
if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
|
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
|
||||||
raise ValueError("Make sure that the number of patches can be squared")
|
raise ValueError(
|
||||||
|
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
|
||||||
|
)
|
||||||
|
|
||||||
# unpatchify
|
# unpatchify
|
||||||
batch_size = patchified_pixel_values.shape[0]
|
batch_size = patchified_pixel_values.shape[0]
|
||||||
patchified_pixel_values = patchified_pixel_values.reshape(
|
patchified_pixel_values = patchified_pixel_values.reshape(
|
||||||
batch_size,
|
batch_size,
|
||||||
num_patches_one_direction,
|
num_patches_h,
|
||||||
num_patches_one_direction,
|
num_patches_w,
|
||||||
patch_size,
|
patch_size,
|
||||||
patch_size,
|
patch_size,
|
||||||
num_channels,
|
num_channels,
|
||||||
@@ -954,12 +1061,12 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
pixel_values = patchified_pixel_values.reshape(
|
pixel_values = patchified_pixel_values.reshape(
|
||||||
batch_size,
|
batch_size,
|
||||||
num_channels,
|
num_channels,
|
||||||
num_patches_one_direction * patch_size,
|
num_patches_h * patch_size,
|
||||||
num_patches_one_direction * patch_size,
|
num_patches_w * patch_size,
|
||||||
)
|
)
|
||||||
return pixel_values
|
return pixel_values
|
||||||
|
|
||||||
def forward_loss(self, pixel_values, pred, mask):
|
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
@@ -968,11 +1075,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
Predicted pixel values.
|
Predicted pixel values.
|
||||||
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||||
Tensor indicating which patches are masked (1) and which are not (0).
|
Tensor indicating which patches are masked (1) and which are not (0).
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
||||||
|
interpolation flag passed during the forward pass.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.FloatTensor`: Pixel reconstruction loss.
|
`torch.FloatTensor`: Pixel reconstruction loss.
|
||||||
"""
|
"""
|
||||||
target = self.patchify(pixel_values)
|
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
if self.config.norm_pix_loss:
|
if self.config.norm_pix_loss:
|
||||||
mean = target.mean(dim=-1, keepdim=True)
|
mean = target.mean(dim=-1, keepdim=True)
|
||||||
var = target.var(dim=-1, keepdim=True)
|
var = target.var(dim=-1, keepdim=True)
|
||||||
@@ -980,7 +1089,6 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
|
|
||||||
loss = (pred - target) ** 2
|
loss = (pred - target) ** 2
|
||||||
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
||||||
|
|
||||||
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@@ -994,6 +1102,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
interpolate_pos_encoding: bool = False,
|
||||||
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
|
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1026,16 +1135,17 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
latent = outputs.last_hidden_state
|
latent = outputs.last_hidden_state
|
||||||
ids_restore = outputs.ids_restore
|
ids_restore = outputs.ids_restore
|
||||||
mask = outputs.mask
|
mask = outputs.mask
|
||||||
|
|
||||||
decoder_outputs = self.decoder(latent, ids_restore)
|
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
|
logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
|
||||||
|
|
||||||
loss = self.forward_loss(pixel_values, logits, mask)
|
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits, mask, ids_restore) + outputs[2:]
|
output = (logits, mask, ids_restore) + outputs[2:]
|
||||||
|
|||||||
@@ -426,7 +426,7 @@ def prepare_img():
|
|||||||
class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
||||||
@cached_property
|
@cached_property
|
||||||
def default_image_processor(self):
|
def default_image_processor(self):
|
||||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_for_pretraining(self):
|
def test_inference_for_pretraining(self):
|
||||||
@@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
|
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
# ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
|
||||||
|
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||||
|
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||||
|
# to visualize self-attention on higher resolution images.
|
||||||
|
|
||||||
|
# make random mask reproducible across the PT and TF model
|
||||||
|
np.random.seed(2)
|
||||||
|
|
||||||
|
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
|
||||||
|
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = image_processor(images=image, do_resize=False, return_tensors="tf")
|
||||||
|
|
||||||
|
# prepare a noise vector that will be also used for testing the TF model
|
||||||
|
# (this way we can ensure that the PT and TF models operate on the same inputs)
|
||||||
|
vit_mae_config = ViTMAEConfig()
|
||||||
|
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
|
||||||
|
noise = np.random.uniform(size=(1, num_patches))
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
|
||||||
|
|
||||||
|
# verify the logits
|
||||||
|
expected_shape = tf.convert_to_tensor([1, 1200, 768])
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ def prepare_img():
|
|||||||
class ViTMAEModelIntegrationTest(unittest.TestCase):
|
class ViTMAEModelIntegrationTest(unittest.TestCase):
|
||||||
@cached_property
|
@cached_property
|
||||||
def default_image_processor(self):
|
def default_image_processor(self):
|
||||||
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
|
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_for_pretraining(self):
|
def test_inference_for_pretraining(self):
|
||||||
@@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_interpolate_pos_encoding(self):
|
||||||
|
# ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
|
||||||
|
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||||
|
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||||
|
# to visualize self-attention on higher resolution images.
|
||||||
|
|
||||||
|
# make random mask reproducible across the PT and TF model
|
||||||
|
np.random.seed(2)
|
||||||
|
|
||||||
|
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
|
||||||
|
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device)
|
||||||
|
|
||||||
|
# prepare a noise vector that will be also used for testing the TF model
|
||||||
|
# (this way we can ensure that the PT and TF models operate on the same inputs)
|
||||||
|
vit_mae_config = ViTMAEConfig()
|
||||||
|
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
|
||||||
|
noise = np.random.uniform(size=(1, num_patches))
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(
|
||||||
|
**inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# verify the logits
|
||||||
|
expected_shape = torch.Size((1, 1200, 768))
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user