Add Swin backbone (#20769)

* Add Swin backbone

* Remove line

* Add code example

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-12-14 19:35:28 +01:00
committed by GitHub
parent 94f8e21c70
commit 67acb07e9e
10 changed files with 256 additions and 41 deletions

View File

@@ -30,7 +30,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
from transformers import SwinBackbone, SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -66,6 +66,7 @@ class SwinModelTester:
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
out_features=["stage1", "stage2"],
):
self.parent = parent
self.batch_size = batch_size
@@ -91,6 +92,7 @@ class SwinModelTester:
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.out_features = out_features
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -123,6 +125,7 @@ class SwinModelTester:
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -136,6 +139,33 @@ class SwinModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = SwinBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = SwinBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = SwinForMaskedImageModeling(config=config)
model.to(torch_device)
@@ -190,6 +220,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
SwinModel,
SwinBackbone,
SwinForImageClassification,
SwinForMaskedImageModeling,
)
@@ -222,6 +253,10 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -230,8 +265,12 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@unittest.skip(reason="Swin does not use inputs_embeds")
def test_inputs_embeds(self):
# Swin does not use inputs_embeds
pass
@unittest.skip(reason="Swin Transformer does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_model_common_attributes(self):
@@ -299,11 +338,8 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
else:
# also another +1 for reshaped_hidden_states
added_hidden_states = 2
# also another +1 for reshaped_hidden_states
added_hidden_states = 1 if model_class.__name__ == "SwinBackbone" else 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
@@ -344,17 +380,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
if not model_class.__name__ == "SwinBackbone":
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()