From 9ef46659da45f6b605873ca59124d03976990b33 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 22 Nov 2022 17:00:08 +0100 Subject: [PATCH] Improve backbone (#20380) Co-authored-by: Niels Rogge --- src/transformers/models/resnet/modeling_resnet.py | 13 ++++++------- tests/models/resnet/test_modeling_resnet.py | 7 ++++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 0988e478dd..4d16bad993 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -440,13 +440,12 @@ class ResNetBackbone(ResNetPreTrainedModel): self.out_features = config.out_features - self.out_feature_channels = { - "stem": config.embedding_size, - "stage1": config.hidden_sizes[0], - "stage2": config.hidden_sizes[1], - "stage3": config.hidden_sizes[2], - "stage4": config.hidden_sizes[3], - } + out_feature_channels = {} + out_feature_channels["stem"] = config.embedding_size + for idx, stage in enumerate(self.stage_names[1:]): + out_feature_channels[stage] = config.hidden_sizes[idx] + + self.out_feature_channels = out_feature_channels # initialize weights and apply final processing self.post_init() diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py index f2e5048438..0c230d1657 100644 --- a/tests/models/resnet/test_modeling_resnet.py +++ b/tests/models/resnet/test_modeling_resnet.py @@ -55,7 +55,7 @@ class ResNetModelTester: hidden_act="relu", num_labels=3, scope=None, - out_features=["stage1", "stage2", "stage3", "stage4"], + out_features=["stage2", "stage3", "stage4"], ): self.parent = parent self.batch_size = batch_size @@ -121,10 +121,11 @@ class ResNetModelTester: # verify hidden states self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) - self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8]) + self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) # verify channels - self.parent.assertListEqual(model.channels, config.hidden_sizes) + self.parent.assertEqual(len(model.channels), len(config.out_features)) + self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs()