Improve backbone (#20380)

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-11-22 17:00:08 +01:00
committed by GitHub
parent 5efd074af0
commit 9ef46659da
2 changed files with 10 additions and 10 deletions

View File

@@ -55,7 +55,7 @@ class ResNetModelTester:
hidden_act="relu",
num_labels=3,
scope=None,
out_features=["stage1", "stage2", "stage3", "stage4"],
out_features=["stage2", "stage3", "stage4"],
):
self.parent = parent
self.batch_size = batch_size
@@ -121,10 +121,11 @@ class ResNetModelTester:
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8])
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
# verify channels
self.parent.assertListEqual(model.channels, config.hidden_sizes)
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()