[ViTMAE] Fix docstrings and variable names (#17710)
* Fix docstrings and variable names * Rename x to something better * Improve messages * Fix docstrings and add test for greyscale images Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -140,6 +140,15 @@ class TFViTMAEModelTester:
|
||||
expected_num_channels = self.patch_size**2 * self.num_channels
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
model = TFViTMAEForPreTraining(config)
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values, training=False)
|
||||
expected_num_channels = self.patch_size**2
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, pixel_values, labels) = config_and_inputs
|
||||
|
||||
@@ -137,6 +137,16 @@ class ViTMAEModelTester:
|
||||
expected_num_channels = self.patch_size**2 * self.num_channels
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
||||
|
||||
# test greyscale images
|
||||
config.num_channels = 1
|
||||
model = ViTMAEForPreTraining(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values)
|
||||
expected_num_channels = self.patch_size**2
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
|
||||
Reference in New Issue
Block a user