Add DeiT (PyTorch) (#11056)
* First draft of deit * More improvements * Remove DeiTTokenizerFast from init * Conversion script works * Add DeiT to ViT conversion script * Add tests, add head model, add support for deit in vit conversion script * Update model checkpoint names * Update image_mean and image_std, set resample to bicubic * Improve docs * Docs improvements * Add DeiTForImageClassificationWithTeacher to init * Address comments by @sgugger * Improve feature extractors * Make fix-copies * Minor fixes * Address comments by @patil-suraj * All models uploaded * Fix tests * Remove labels argument from DeiTForImageClassificationWithTeacher * Fix-copies, style and quality * Fix tests * Fix typo * Multiple docs improvements * More docs fixes
This commit is contained in:
@@ -155,20 +155,10 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ViTModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ViTConfig, hidden_size=37)
|
||||
self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
config = self.config_tester.config_class(**self.config_tester.inputs_dict)
|
||||
# we omit vocab_size since ViT does not use this
|
||||
self.config_tester.parent.assertTrue(hasattr(config, "hidden_size"))
|
||||
self.config_tester.parent.assertTrue(hasattr(config, "num_attention_heads"))
|
||||
self.config_tester.parent.assertTrue(hasattr(config, "num_hidden_layers"))
|
||||
|
||||
self.config_tester.create_and_test_config_to_json_string()
|
||||
self.config_tester.create_and_test_config_to_json_file()
|
||||
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||
self.config_tester.create_and_test_config_with_num_labels()
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# ViT does not use inputs_embeds
|
||||
@@ -351,10 +341,7 @@ class ViTModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
# currently failing
|
||||
# see https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/2
|
||||
outputs = model(inputs["pixel_values"])
|
||||
# outputs = model(**inputs)
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
|
||||
Reference in New Issue
Block a user