Add SimMIM (#15586)
* Add first draft
* Make model importable
* Make SwinForMaskedImageModeling importable
* Fix imports
* Add missing inits
* Add support for Swin
* Fix bug
* Fix bug
* Fix another bug
* Fix Swin MIM implementation
* Fix default encoder stride
* Fix Swin
* Add print statements for debugging
* Add image_size data argument
* Fix Swin
* Fix image_size
* Add print statements for debugging
* Fix print statement
* Remove print statements
* Improve reshaping of bool_masked_pos
* Add support for DeiT, fix tests
* Improve docstrings
* Apply new black version
* Improve script
* Fix bug
* Improve README
* Apply suggestions from code review
* Remove DS_Store and add to gitignore
* Apply suggestions from code review + fix BEiT Flax
* Revert BEiT changes
* Improve README
* Fix code quality
* Improve README

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -72,6 +72,7 @@ if is_torch_available():
|
||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@@ -165,6 +166,11 @@ class ModelTesterMixin:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in get_values(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
|
||||
num_patches = self.model_tester.image_size // self.model_tester.patch_size
|
||||
inputs_dict["bool_masked_pos"] = torch.zeros(
|
||||
(self.model_tester.batch_size, num_patches**2), dtype=torch.long, device=torch_device
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def test_save_load(self):
|
||||
|
||||
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
MODEL_MAPPING,
|
||||
DeiTForImageClassification,
|
||||
DeiTForImageClassificationWithTeacher,
|
||||
DeiTForMaskedImageModeling,
|
||||
DeiTModel,
|
||||
)
|
||||
from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||
@@ -67,6 +68,7 @@ class DeiTModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
encoder_stride=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -85,6 +87,7 @@ class DeiTModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.encoder_stride = encoder_stride
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -111,6 +114,7 @@ class DeiTModelTester:
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -155,6 +159,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
DeiTModel,
|
||||
DeiTForImageClassification,
|
||||
DeiTForImageClassificationWithTeacher,
|
||||
DeiTForMaskedImageModeling,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
|
||||
@@ -31,7 +31,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import SwinForImageClassification, SwinModel
|
||||
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
|
||||
from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||
|
||||
if is_vision_available():
|
||||
@@ -74,6 +74,7 @@ class SwinModelTester:
|
||||
scope=None,
|
||||
use_labels=True,
|
||||
type_sequence_label_size=10,
|
||||
encoder_stride=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -98,6 +99,7 @@ class SwinModelTester:
|
||||
self.scope = scope
|
||||
self.use_labels = use_labels
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.encoder_stride = encoder_stride
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -129,6 +131,7 @@ class SwinModelTester:
|
||||
path_norm=self.patch_norm,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -169,6 +172,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
(
|
||||
SwinModel,
|
||||
SwinForImageClassification,
|
||||
SwinForMaskedImageModeling,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
|
||||
@@ -30,7 +30,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import ViTForImageClassification, ViTModel
|
||||
from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
|
||||
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ class ViTModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
encoder_stride=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -79,6 +80,7 @@ class ViTModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.encoder_stride = encoder_stride
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
@@ -105,6 +107,7 @@ class ViTModelTester:
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
encoder_stride=self.encoder_stride,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
@@ -148,6 +151,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
(
|
||||
ViTModel,
|
||||
ViTForImageClassification,
|
||||
ViTForMaskedImageModeling,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
|
||||
Reference in New Issue
Block a user