[Hotfix] Fix Swin model outputs (#15414)

* Fix Swin model outputs

* Rename pooler
This commit is contained in:
NielsRogge
2022-01-31 16:32:14 +01:00
committed by GitHub
parent 38dfb40ae3
commit d4b3e56d64
2 changed files with 40 additions and 21 deletions

View File

@@ -137,9 +137,11 @@ class SwinModelTester:
model.eval()
result = model(pixel_values)
num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len = (config.image_size // config.patch_size) ** 2
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device)
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))