Swin support for any input size (#15986)
* padding done * correctly return one attention per layer * almost correct, attentions are not flatten one tuple per stage * tests green * doc * conversations * reshaping hidden_states * view in the test * reshape_hidden_states in Encoder and Model * new outputs with reshaped_hidden_states * conversations * doc * Update docs/source/model_doc/swin.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * conversations * fix tests * minor changes * resolved conversations * attentions one per stage * typo * typos * typos * function signature * CI * clean up tests Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
204c54d411
commit
667b823b89
@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), len(self.model_tester.depths))
|
||||
attentions = outputs.attentions
|
||||
expected_num_attentions = len(self.model_tester.depths)
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
@@ -260,19 +252,13 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), len(self.model_tester.depths))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), expected_num_attentions)
|
||||
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
@@ -286,25 +272,19 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
# also another +1 for reshaped_hidden_states
|
||||
added_hidden_states = 2
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), len(self.model_tester.depths))
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
self.assertEqual(len(self_attentions), expected_num_attentions)
|
||||
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_heads[0], window_size_squared, window_size_squared],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
|
||||
@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# Swin has a different seq_length
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
|
||||
self.assertListEqual(
|
||||
@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
reshaped_hidden_states = outputs.reshaped_hidden_states
|
||||
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
|
||||
|
||||
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
|
||||
reshaped_hidden_states = (
|
||||
reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(reshaped_hidden_states.shape[-2:]),
|
||||
[num_patches, self.model_tester.embed_dim],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase):
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user