Add Swinv2 backbone (#27742)

* First draft

* More improvements

* More improvements

* Make all tests pass

* Remove script

* Update image processor

* Address comments

* Use new gradient checkpointing method

* Convert checkpoints, add integration test

* Do not keep aspect ratio for now

* Set keep_aspect_ratio=False for beit, add integration test

* Remove print statement
This commit is contained in:
NielsRogge
2023-12-22 12:12:56 +01:00
committed by GitHub
parent 1ef86c4f56
commit c9fb250a25
14 changed files with 667 additions and 130 deletions

View File

@@ -14,12 +14,14 @@
# limitations under the License.
""" Testing suite for the PyTorch Swinv2 model. """
import collections
import inspect
import unittest
from transformers import Swinv2Config
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -29,7 +31,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model
from transformers import Swinv2Backbone, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model
from transformers.models.swinv2.modeling_swinv2 import SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
@@ -65,6 +67,8 @@ class Swinv2ModelTester:
use_labels=True,
type_sequence_label_size=10,
encoder_stride=8,
out_features=["stage1", "stage2"],
out_indices=[1, 2],
):
self.parent = parent
self.batch_size = batch_size
@@ -90,6 +94,8 @@ class Swinv2ModelTester:
self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size
self.encoder_stride = encoder_stride
self.out_features = out_features
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -122,6 +128,8 @@ class Swinv2ModelTester:
layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
out_features=self.out_features,
out_indices=self.out_indices,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -135,6 +143,33 @@ class Swinv2ModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = Swinv2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = Swinv2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
model = Swinv2ForMaskedImageModeling(config=config)
model.to(torch_device)
@@ -172,7 +207,14 @@ class Swinv2ModelTester:
@require_torch
class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else ()
(
Swinv2Model,
Swinv2ForImageClassification,
Swinv2ForMaskedImageModeling,
Swinv2Backbone,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
@@ -201,6 +243,10 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
# TODO: check if this works again for PyTorch 2.x.y
@unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.")
def test_multi_gpu_data_parallel_forward(self):
@@ -219,6 +265,18 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
@@ -263,11 +321,8 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
else:
# also another +1 for reshaped_hidden_states
added_hidden_states = 2
# also another +1 for reshaped_hidden_states
added_hidden_states = 1 if model_class.__name__ == "Swinv2Backbone" else 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
@@ -308,17 +363,18 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
[num_patches, self.model_tester.embed_dim],
)
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
if not model_class.__name__ == "Swinv2Backbone":
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = (
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -380,6 +436,10 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = Swinv2Model.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip(reason="Swinv2 does not support feedforward chunking yet")
def test_feed_forward_chunking(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -425,3 +485,12 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@require_torch
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
all_model_classes = (Swinv2Backbone,) if is_torch_available() else ()
config_class = Swinv2Config
def setUp(self):
self.model_tester = Swinv2ModelTester(self)