Create MaskedImageCompletionOutput and fix ViT docs (#22152)

* create MaskedImageCompletionOutput

* fix bugs

* fix bugs
This commit is contained in:
Alara Dirik
2023-03-14 16:55:18 +03:00
committed by GitHub
parent b45192ec47
commit 3b22bfbc6a
3 changed files with 40 additions and 7 deletions

View File

@@ -134,7 +134,7 @@ class ViTModelTester:
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
)
# test greyscale images
@@ -145,7 +145,7 @@ class ViTModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size