Add SimMIM (#15586)
* Add first draft * Make model importable * Make SwinForMaskedImageModeling importable * Fix imports * Add missing inits * Add support for Swin * Fix bug * Fix bug * Fix another bug * Fix Swin MIM implementation * Fix default encoder stride * Fix Swin * Add print statements for debugging * Add image_size data argument * Fix Swin * Fix image_size * Add print statements for debugging * Fix print statement * Remove print statements * Improve reshaping of bool_masked_pos * Add support for DeiT, fix tests * Improve docstrings * Apply new black version * Improve script * Fix bug * Improve README * Apply suggestions from code review * Remove DS_Store and add to gitignore * Apply suggestions from code review + fix BEiT Flax * Revert BEiT changes * Improve README * Fix code quality * Improve README Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -72,6 +72,7 @@ if is_torch_available():
|
||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@@ -165,6 +166,11 @@ class ModelTesterMixin:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in get_values(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
|
||||
num_patches = self.model_tester.image_size // self.model_tester.patch_size
|
||||
inputs_dict["bool_masked_pos"] = torch.zeros(
|
||||
(self.model_tester.batch_size, num_patches**2), dtype=torch.long, device=torch_device
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def test_save_load(self):
|
||||
|
||||
Reference in New Issue
Block a user