From b681e12d5963490d29c2a77ba7346ee050e46def Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 21 Jun 2022 15:56:00 +0200 Subject: [PATCH] [ViTMAE] Fix docstrings and variable names (#17710) * Fix docstrings and variable names * Rename x to something better * Improve messages * Fix docstrings and add test for greyscale images Co-authored-by: Niels Rogge --- .../models/vit_mae/modeling_tf_vit_mae.py | 118 +++++++++++++----- .../models/vit_mae/modeling_vit_mae.py | 105 +++++++++++----- .../vit_mae/test_modeling_tf_vit_mae.py | 9 ++ tests/models/vit_mae/test_modeling_vit_mae.py | 10 ++ 4 files changed, 185 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index 803a7cccc7..24d84141c8 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -84,7 +84,7 @@ class TFViTMAEDecoderOutput(ModelOutput): Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions. Args: - logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`): + logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): Pixel reconstruction logits. hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape @@ -109,7 +109,7 @@ class TFViTMAEForPreTrainingOutput(ModelOutput): Args: loss (`tf.Tensor` of shape `(1,)`): Pixel reconstruction loss. - logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`): + logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): Pixel reconstruction logits. mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): Tensor indicating which patches are masked (1) and which are not (0). @@ -969,50 +969,110 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): def _prune_heads(self, heads_to_prune): raise NotImplementedError - def patchify(self, imgs): + def patchify(self, pixel_values): """ - imgs: (batch_size, height, width, 3) x: (batch_size, num_patches, patch_size**2 *3) + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`): + Pixel values. + + Returns: + `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. """ - imgs = tf.cond( - tf.math.equal(shape_list(imgs)[1], 3), lambda: tf.transpose(imgs, perm=(0, 2, 3, 1)), lambda: imgs + patch_size, num_channels = self.config.patch_size, self.config.num_channels + # make sure channels are last + pixel_values = tf.cond( + tf.math.equal(shape_list(pixel_values)[1], num_channels), + lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)), + lambda: pixel_values, ) - p = self.vit.embeddings.patch_embeddings.patch_size[0] - tf.debugging.assert_equal(shape_list(imgs)[1], shape_list(imgs)[2]) - tf.debugging.assert_equal(shape_list(imgs)[1] % p, 0) + # sanity checks + tf.debugging.assert_equal( + shape_list(pixel_values)[1], + shape_list(pixel_values)[2], + message="Make sure the pixel values have a squared size", + ) + tf.debugging.assert_equal( + shape_list(pixel_values)[1] % patch_size, + 0, + message="Make sure the pixel values have a size that is divisible by the patch size", + ) + tf.debugging.assert_equal( + shape_list(pixel_values)[3], + num_channels, + message=( + "Make sure the number of channels of the pixel values is equal to the one set in the configuration" + ), + ) - h = w = shape_list(imgs)[2] // p - x = tf.reshape(imgs, (shape_list(imgs)[0], h, p, w, p, 3)) - x = tf.einsum("nhpwqc->nhwpqc", x) - x = tf.reshape(x, (shape_list(imgs)[0], h * w, p**2 * 3)) - return x + # patchify + batch_size = shape_list(pixel_values)[0] + num_patches_one_direction = shape_list(pixel_values)[2] // patch_size + patchified_pixel_values = tf.reshape( + pixel_values, + (batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels), + ) + patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values) + patchified_pixel_values = tf.reshape( + patchified_pixel_values, + (batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels), + ) + return patchified_pixel_values - def unpatchify(self, x): + def unpatchify(self, patchified_pixel_values): """ - x: (batch_size, num_patches, patch_size**2 *3) imgs: (batch_size, height, width, 3) - """ - p = self.vit.embeddings.patch_embeddings.patch_size[0] - h = w = int(shape_list(x)[1] ** 0.5) - tf.debugging.assert_equal(h * w, shape_list(x)[1]) + Args: + patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. - x = tf.reshape(x, (shape_list(x)[0], h, w, p, p, 3)) - x = tf.einsum("nhwpqc->nhpwqc", x) - imgs = tf.reshape(x, (shape_list(x)[0], h * p, h * p, 3)) - return imgs + Returns: + `tf.Tensor` of shape `(batch_size, height, width, num_channels)`: + Pixel values. + """ + patch_size, num_channels = self.config.patch_size, self.config.num_channels + num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5) + # sanity check + tf.debugging.assert_equal( + num_patches_one_direction * num_patches_one_direction, + shape_list(patchified_pixel_values)[1], + message="Make sure that the number of patches can be squared", + ) - def forward_loss(self, imgs, pred, mask): + # unpatchify + batch_size = shape_list(patchified_pixel_values)[0] + patchified_pixel_values = tf.reshape( + patchified_pixel_values, + (batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels), + ) + patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values) + pixel_values = tf.reshape( + patchified_pixel_values, + (batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels), + ) + return pixel_values + + def forward_loss(self, pixel_values, pred, mask): """ - imgs: [batch_size, height, width, 3] pred: [batch_size, num_patches, patch_size**2*3] mask: [N, L], 0 is keep, - 1 is remove, + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`): + Pixel values. + pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Predicted pixel values. + mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + + Returns: + `tf.Tensor`: Pixel reconstruction loss. """ - target = self.patchify(imgs) + target = self.patchify(pixel_values) if self.config.norm_pix_loss: mean = tf.reduce_mean(target, axis=-1, keepdims=True) var = tf.math.reduce_variance(target, axis=-1, keepdims=True) target = (target - mean) / (var + 1.0e-6) ** 0.5 loss = (pred - target) ** 2 - loss = tf.reduce_mean(loss, axis=-1) # [N, L], mean loss per patch + loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches return loss diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index f827978739..1226b6025c 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -86,7 +86,7 @@ class ViTMAEDecoderOutput(ModelOutput): Class for ViTMAEDecoder's outputs, with potential hidden states and attentions. Args: - logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`): + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): Pixel reconstruction logits. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of @@ -111,7 +111,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput): Args: loss (`torch.FloatTensor` of shape `(1,)`): Pixel reconstruction loss. - logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`): + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): Pixel reconstruction logits. mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Tensor indicating which patches are masked (1) and which are not (0). @@ -868,37 +868,86 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def patchify(self, imgs): + def patchify(self, pixel_values): """ - imgs: (N, 3, H, W) x: (N, L, patch_size**2 *3) - """ - p = self.vit.embeddings.patch_embeddings.patch_size[0] - assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. - h = w = imgs.shape[2] // p - x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) - x = torch.einsum("nchpwq->nhwpqc", x) - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) - return x + Returns: + `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. + """ + patch_size, num_channels = self.config.patch_size, self.config.num_channels + # sanity checks + if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0): + raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size") + if pixel_values.shape[1] != num_channels: + raise ValueError( + "Make sure the number of channels of the pixel values is equal to the one set in the configuration" + ) - def unpatchify(self, x): - """ - x: (N, L, patch_size**2 *3) imgs: (N, 3, H, W) - """ - p = self.vit.embeddings.patch_embeddings.patch_size[0] - h = w = int(x.shape[1] ** 0.5) - assert h * w == x.shape[1] + # patchify + batch_size = pixel_values.shape[0] + num_patches_one_direction = pixel_values.shape[2] // patch_size + patchified_pixel_values = pixel_values.reshape( + batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size + ) + patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values) + patchified_pixel_values = patchified_pixel_values.reshape( + batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels + ) + return patchified_pixel_values - x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) - x = torch.einsum("nhwpqc->nchpwq", x) - imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) - return imgs + def unpatchify(self, patchified_pixel_values): + """ + Args: + patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. - def forward_loss(self, imgs, pred, mask): + Returns: + `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: + Pixel values. """ - imgs: [N, 3, H, W] pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove, + patch_size, num_channels = self.config.patch_size, self.config.num_channels + num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5) + # sanity check + if num_patches_one_direction**2 != patchified_pixel_values.shape[1]: + raise ValueError("Make sure that the number of patches can be squared") + + # unpatchify + batch_size = patchified_pixel_values.shape[0] + patchified_pixel_values = patchified_pixel_values.reshape( + batch_size, + num_patches_one_direction, + num_patches_one_direction, + patch_size, + patch_size, + num_channels, + ) + patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) + pixel_values = patchified_pixel_values.reshape( + batch_size, + num_channels, + num_patches_one_direction * patch_size, + num_patches_one_direction * patch_size, + ) + return pixel_values + + def forward_loss(self, pixel_values, pred, mask): """ - target = self.patchify(imgs) + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Predicted pixel values. + mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + + Returns: + `torch.FloatTensor`: Pixel reconstruction loss. + """ + target = self.patchify(pixel_values) if self.config.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) @@ -958,8 +1007,8 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel): ids_restore = outputs.ids_restore mask = outputs.mask - decoder_outputs = self.decoder(latent, ids_restore) # [N, L, p*p*3] - logits = decoder_outputs.logits + decoder_outputs = self.decoder(latent, ids_restore) + logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) loss = self.forward_loss(pixel_values, logits, mask) diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index cb54e29b80..7ce6a80098 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -140,6 +140,15 @@ class TFViTMAEModelTester: expected_num_channels = self.patch_size**2 * self.num_channels self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + # test greyscale images + config.num_channels = 1 + model = TFViTMAEForPreTraining(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values, training=False) + expected_num_channels = self.patch_size**2 + self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, pixel_values, labels) = config_and_inputs diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index 191984d82f..1a749b1628 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -137,6 +137,16 @@ class ViTMAEModelTester: expected_num_channels = self.patch_size**2 * self.num_channels self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + # test greyscale images + config.num_channels = 1 + model = ViTMAEForPreTraining(config) + model.to(torch_device) + model.eval() + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + expected_num_channels = self.patch_size**2 + self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs