[Backbones] Improve out features (#20675)
* Improve ResNet backbone
* Improve Bit backbone
* Improve docstrings
* Fix default stage
* Apply suggestions from code review

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig):
|
|||||||
The width factor for the model.
|
The width factor for the model.
|
||||||
out_features (`List[str]`, *optional*):
|
out_features (`List[str]`, *optional*):
|
||||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||||
(depending on how many stages the model has).
|
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
|
|||||||
@@ -851,7 +851,7 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
|
|||||||
self.stage_names = config.stage_names
|
self.stage_names = config.stage_names
|
||||||
self.bit = BitModel(config)
|
self.bit = BitModel(config)
|
||||||
|
|
||||||
self.out_features = config.out_features
|
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||||
|
|
||||||
out_feature_channels = {}
|
out_feature_channels = {}
|
||||||
out_feature_channels["stem"] = config.embedding_size
|
out_feature_channels["stem"] = config.embedding_size
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig):
|
|||||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
out_features (`List[str]`, *optional*):
|
out_features (`List[str]`, *optional*):
|
||||||
If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
|
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||||
|
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
|||||||
@@ -855,7 +855,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
|
|||||||
self.stage_names = config.stage_names
|
self.stage_names = config.stage_names
|
||||||
self.model = MaskFormerSwinModel(config)
|
self.model = MaskFormerSwinModel(config)
|
||||||
|
|
||||||
self.out_features = config.out_features
|
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||||
if "stem" in self.out_features:
|
if "stem" in self.out_features:
|
||||||
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
|
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
|
||||||
|
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig):
|
|||||||
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
|
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
|
||||||
If `True`, the first stage will downsample the inputs using a `stride` of 2.
|
If `True`, the first stage will downsample the inputs using a `stride` of 2.
|
||||||
out_features (`List[str]`, *optional*):
|
out_features (`List[str]`, *optional*):
|
||||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`,
|
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||||
`"stage3"`, `"stage4"`.
|
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
|
|||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
if isinstance(module, (ResNetModel, ResNetBackbone)):
|
if isinstance(module, ResNetEncoder):
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
@@ -439,7 +439,7 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
|
|||||||
self.embedder = ResNetEmbeddings(config)
|
self.embedder = ResNetEmbeddings(config)
|
||||||
self.encoder = ResNetEncoder(config)
|
self.encoder = ResNetEncoder(config)
|
||||||
|
|
||||||
self.out_features = config.out_features
|
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||||
|
|
||||||
out_feature_channels = {}
|
out_feature_channels = {}
|
||||||
out_feature_channels["stem"] = config.embedding_size
|
out_feature_channels["stem"] = config.embedding_size
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class BitModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
|
|
||||||
# verify hidden states
|
# verify feature maps
|
||||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
||||||
|
|
||||||
@@ -127,6 +127,21 @@ class BitModelTester:
|
|||||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||||
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
||||||
|
|
||||||
|
# verify backbone works with out_features=None
|
||||||
|
config.out_features = None
|
||||||
|
model = BitBackbone(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
# verify feature maps
|
||||||
|
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||||
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
|
||||||
|
|
||||||
|
# verify channels
|
||||||
|
self.parent.assertEqual(len(model.channels), 1)
|
||||||
|
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, pixel_values, labels = config_and_inputs
|
config, pixel_values, labels = config_and_inputs
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class ResNetModelTester:
|
|||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
|
|
||||||
# verify hidden states
|
# verify feature maps
|
||||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
||||||
|
|
||||||
@@ -127,6 +127,21 @@ class ResNetModelTester:
|
|||||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||||
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
||||||
|
|
||||||
|
# verify backbone works with out_features=None
|
||||||
|
config.out_features = None
|
||||||
|
model = ResNetBackbone(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
# verify feature maps
|
||||||
|
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||||
|
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
|
||||||
|
|
||||||
|
# verify channels
|
||||||
|
self.parent.assertEqual(len(model.channels), 1)
|
||||||
|
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, pixel_values, labels = config_and_inputs
|
config, pixel_values, labels = config_and_inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user