Improve backbone (#20380)
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -440,13 +440,12 @@ class ResNetBackbone(ResNetPreTrainedModel):
|
|||||||
|
|
||||||
self.out_features = config.out_features
|
self.out_features = config.out_features
|
||||||
|
|
||||||
self.out_feature_channels = {
|
out_feature_channels = {}
|
||||||
"stem": config.embedding_size,
|
out_feature_channels["stem"] = config.embedding_size
|
||||||
"stage1": config.hidden_sizes[0],
|
for idx, stage in enumerate(self.stage_names[1:]):
|
||||||
"stage2": config.hidden_sizes[1],
|
out_feature_channels[stage] = config.hidden_sizes[idx]
|
||||||
"stage3": config.hidden_sizes[2],
|
|
||||||
"stage4": config.hidden_sizes[3],
|
self.out_feature_channels = out_feature_channels
|
||||||
}
|
|
||||||
|
|
||||||
# initialize weights and apply final processing
|
# initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class ResNetModelTester:
|
|||||||
hidden_act="relu",
|
hidden_act="relu",
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
scope=None,
|
scope=None,
|
||||||
out_features=["stage1", "stage2", "stage3", "stage4"],
|
out_features=["stage2", "stage3", "stage4"],
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -121,10 +121,11 @@ class ResNetModelTester:
|
|||||||
|
|
||||||
# verify hidden states
|
# verify hidden states
|
||||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8])
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
||||||
|
|
||||||
# verify channels
|
# verify channels
|
||||||
self.parent.assertListEqual(model.channels, config.hidden_sizes)
|
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||||
|
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user