Add SimMIM (#15586)

* Add first draft

* Make model importable

* Make SwinForMaskedImageModeling importable

* Fix imports

* Add missing inits

* Add support for Swin

* Fix bug

* Fix bug

* Fix another bug

* Fix Swin MIM implementation

* Fix default encoder stride

* Fix Swin

* Add print statements for debugging

* Add image_size data argument

* Fix Swin

* Fix image_size

* Add print statements for debugging

* Fix print statement

* Remove print statements

* Improve reshaping of bool_masked_pos

* Add support for DeiT, fix tests

* Improve docstrings

* Apply new black version

* Improve script

* Fix bug

* Improve README

* Apply suggestions from code review

* Remove .DS_Store and add it to .gitignore

* Apply suggestions from code review + fix BEiT Flax

* Revert BEiT changes

* Improve README

* Fix code quality

* Improve README

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-02-17 19:44:55 +01:00
committed by GitHub
parent 426b96230a
commit 57882177be
26 changed files with 1075 additions and 51 deletions

View File

@@ -72,6 +72,7 @@ if is_torch_available():
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -165,6 +166,11 @@ class ModelTesterMixin:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
)
elif model_class in get_values(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
num_patches = self.model_tester.image_size // self.model_tester.patch_size
inputs_dict["bool_masked_pos"] = torch.zeros(
(self.model_tester.batch_size, num_patches**2), dtype=torch.long, device=torch_device
)
return inputs_dict
def test_save_load(self):

View File

@@ -35,6 +35,7 @@ if is_torch_available():
MODEL_MAPPING,
DeiTForImageClassification,
DeiTForImageClassificationWithTeacher,
DeiTForMaskedImageModeling,
DeiTModel,
)
from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -67,6 +68,7 @@ class DeiTModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
encoder_stride=2,
):
self.parent = parent
self.batch_size = batch_size
@@ -85,6 +87,7 @@ class DeiTModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,6 +114,7 @@ class DeiTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -155,6 +159,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
DeiTModel,
DeiTForImageClassification,
DeiTForImageClassificationWithTeacher,
DeiTForMaskedImageModeling,
)
if is_torch_available()
else ()

View File

@@ -31,7 +31,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import SwinForImageClassification, SwinModel
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
if is_vision_available():
@@ -74,6 +74,7 @@ class SwinModelTester:
scope=None,
use_labels=True,
type_sequence_label_size=10,
encoder_stride=2,
):
self.parent = parent
self.batch_size = batch_size
@@ -98,6 +99,7 @@ class SwinModelTester:
self.scope = scope
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -129,6 +131,7 @@ class SwinModelTester:
path_norm=self.patch_norm,
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -169,6 +172,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
(
SwinModel,
SwinForImageClassification,
SwinForMaskedImageModeling,
)
if is_torch_available()
else ()

View File

@@ -30,7 +30,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import ViTForImageClassification, ViTModel
from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -61,6 +61,7 @@ class ViTModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
encoder_stride=2,
):
self.parent = parent
self.batch_size = batch_size
@@ -79,6 +80,7 @@ class ViTModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -105,6 +107,7 @@ class ViTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -148,6 +151,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
(
ViTModel,
ViTForImageClassification,
ViTForMaskedImageModeling,
)
if is_torch_available()
else ()