Improve vision models (#17731)

* Improve vision models

* Add a lot of improvements

* Remove to_2tuple from swin tests

* Fix TF Swin

* Fix more tests

* Fix copies

* Improve more models

* Fix ViTMAE test

* Add channel check for TF models

* Add proper channel check for TF models

* Apply suggestion from code review

* Apply suggestions from code review

* Add channel check for Flax models, apply suggestion

* Fix bug

* Add tests for greyscale images

* Add test for interpolation of pos encodigns

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-06-24 11:34:51 +02:00
committed by GitHub
parent 893ab12452
commit 0917870510
39 changed files with 801 additions and 916 deletions

View File

@@ -91,8 +91,7 @@ class FlaxViTModelTester(unittest.TestCase):
return config, pixel_values
def create_and_check_model(self, config, pixel_values, labels):
def create_and_check_model(self, config, pixel_values):
model = FlaxViTModel(config=config)
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
@@ -101,6 +100,19 @@ class FlaxViTModelTester(unittest.TestCase):
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values):
config.num_labels = self.type_sequence_label_size
model = FlaxViTForImageClassification(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = FlaxViTForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -123,7 +135,15 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
# We neeed to override this test because ViT's forward signature is different than text models.
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
# We need to override this test because ViT's forward signature is different than text models.
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()