[ViTMAE] Fix docstrings and variable names (#17710)

* Fix docstrings and variable names

* Rename x to something better

* Improve messages

* Fix docstrings and add test for greyscale images

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-06-21 15:56:00 +02:00
committed by GitHub
parent 3fab17fce8
commit b681e12d59
4 changed files with 185 additions and 57 deletions

View File

@@ -140,6 +140,15 @@ class TFViTMAEModelTester:
expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
# test greyscale images
config.num_channels = 1
model = TFViTMAEForPreTraining(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, training=False)
expected_num_channels = self.patch_size**2
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, pixel_values, labels) = config_and_inputs

View File

@@ -137,6 +137,16 @@ class ViTMAEModelTester:
expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
# test greyscale images
config.num_channels = 1
model = ViTMAEForPreTraining(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
expected_num_channels = self.patch_size**2
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs