Backbone kwargs in config (#28784)
* Enable instantiating model with pretrained backbone weights * Clarify pretrained import * Use load_backbone instead * Add backbone_kwargs to config * Pass kwargs to constructors * Fix up * Input verification * Add tests * Tidy up * Update tests/utils/test_backbone_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -98,6 +98,9 @@ class ConditionalDetrConfig(PretrainedConfig):
|
|||||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||||
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to use pretrained weights for the backbone.
|
Whether to use pretrained weights for the backbone.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
dilation (`bool`, *optional*, defaults to `False`):
|
dilation (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
||||||
`use_timm_backbone` = `True`.
|
`use_timm_backbone` = `True`.
|
||||||
@@ -168,6 +171,7 @@ class ConditionalDetrConfig(PretrainedConfig):
|
|||||||
position_embedding_type="sine",
|
position_embedding_type="sine",
|
||||||
backbone="resnet50",
|
backbone="resnet50",
|
||||||
use_pretrained_backbone=True,
|
use_pretrained_backbone=True,
|
||||||
|
backbone_kwargs=None,
|
||||||
dilation=False,
|
dilation=False,
|
||||||
class_cost=2,
|
class_cost=2,
|
||||||
bbox_cost=5,
|
bbox_cost=5,
|
||||||
@@ -191,6 +195,9 @@ class ConditionalDetrConfig(PretrainedConfig):
|
|||||||
if backbone_config is not None and use_timm_backbone:
|
if backbone_config is not None and use_timm_backbone:
|
||||||
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if not use_timm_backbone:
|
if not use_timm_backbone:
|
||||||
if backbone_config is None:
|
if backbone_config is None:
|
||||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
||||||
@@ -224,6 +231,7 @@ class ConditionalDetrConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
# Hungarian matcher
|
# Hungarian matcher
|
||||||
self.class_cost = class_cost
|
self.class_cost = class_cost
|
||||||
|
|||||||
@@ -90,6 +90,9 @@ class DeformableDetrConfig(PretrainedConfig):
|
|||||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||||
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to use pretrained weights for the backbone.
|
Whether to use pretrained weights for the backbone.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
dilation (`bool`, *optional*, defaults to `False`):
|
dilation (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
||||||
`use_timm_backbone` = `True`.
|
`use_timm_backbone` = `True`.
|
||||||
@@ -177,6 +180,7 @@ class DeformableDetrConfig(PretrainedConfig):
|
|||||||
position_embedding_type="sine",
|
position_embedding_type="sine",
|
||||||
backbone="resnet50",
|
backbone="resnet50",
|
||||||
use_pretrained_backbone=True,
|
use_pretrained_backbone=True,
|
||||||
|
backbone_kwargs=None,
|
||||||
dilation=False,
|
dilation=False,
|
||||||
num_feature_levels=4,
|
num_feature_levels=4,
|
||||||
encoder_n_points=4,
|
encoder_n_points=4,
|
||||||
@@ -207,6 +211,9 @@ class DeformableDetrConfig(PretrainedConfig):
|
|||||||
if backbone_config is not None and use_timm_backbone:
|
if backbone_config is not None and use_timm_backbone:
|
||||||
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if not use_timm_backbone:
|
if not use_timm_backbone:
|
||||||
if backbone_config is None:
|
if backbone_config is None:
|
||||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
||||||
@@ -238,6 +245,7 @@ class DeformableDetrConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
# deformable attributes
|
# deformable attributes
|
||||||
self.num_feature_levels = num_feature_levels
|
self.num_feature_levels = num_feature_levels
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ class DetaConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, `False`):
|
use_timm_backbone (`bool`, *optional*, `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
num_queries (`int`, *optional*, defaults to 900):
|
num_queries (`int`, *optional*, defaults to 900):
|
||||||
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
|
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
|
||||||
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
|
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
|
||||||
@@ -150,6 +153,7 @@ class DetaConfig(PretrainedConfig):
|
|||||||
backbone=None,
|
backbone=None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone=False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone=False,
|
||||||
|
backbone_kwargs=None,
|
||||||
num_queries=900,
|
num_queries=900,
|
||||||
max_position_embeddings=2048,
|
max_position_embeddings=2048,
|
||||||
encoder_layers=6,
|
encoder_layers=6,
|
||||||
@@ -204,10 +208,14 @@ class DetaConfig(PretrainedConfig):
|
|||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.num_queries = num_queries
|
self.num_queries = num_queries
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
|
|||||||
@@ -98,6 +98,9 @@ class DetrConfig(PretrainedConfig):
|
|||||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||||
use_pretrained_backbone (`bool`, *optional*, `True`):
|
use_pretrained_backbone (`bool`, *optional*, `True`):
|
||||||
Whether to use pretrained weights for the backbone.
|
Whether to use pretrained weights for the backbone.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
dilation (`bool`, *optional*, defaults to `False`):
|
dilation (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
||||||
`use_timm_backbone` = `True`.
|
`use_timm_backbone` = `True`.
|
||||||
@@ -166,6 +169,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
position_embedding_type="sine",
|
position_embedding_type="sine",
|
||||||
backbone="resnet50",
|
backbone="resnet50",
|
||||||
use_pretrained_backbone=True,
|
use_pretrained_backbone=True,
|
||||||
|
backbone_kwargs=None,
|
||||||
dilation=False,
|
dilation=False,
|
||||||
class_cost=1,
|
class_cost=1,
|
||||||
bbox_cost=5,
|
bbox_cost=5,
|
||||||
@@ -188,6 +192,9 @@ class DetrConfig(PretrainedConfig):
|
|||||||
if backbone_config is not None and use_timm_backbone:
|
if backbone_config is not None and use_timm_backbone:
|
||||||
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if not use_timm_backbone:
|
if not use_timm_backbone:
|
||||||
if backbone_config is None:
|
if backbone_config is None:
|
||||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
||||||
@@ -223,6 +230,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
# Hungarian matcher
|
# Hungarian matcher
|
||||||
self.class_cost = class_cost
|
self.class_cost = class_cost
|
||||||
|
|||||||
@@ -120,6 +120,9 @@ class DPTConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -173,6 +176,7 @@ class DPTConfig(PretrainedConfig):
|
|||||||
backbone=None,
|
backbone=None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone=False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone=False,
|
||||||
|
backbone_kwargs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -230,9 +234,13 @@ class DPTConfig(PretrainedConfig):
|
|||||||
if use_autobackbone and backbone_config is not None and backbone is not None:
|
if use_autobackbone and backbone_config is not None and backbone is not None:
|
||||||
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
|
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
|
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
|
||||||
self.num_attention_heads = None if use_autobackbone else num_attention_heads
|
self.num_attention_heads = None if use_autobackbone else num_attention_heads
|
||||||
self.intermediate_size = None if use_autobackbone else intermediate_size
|
self.intermediate_size = None if use_autobackbone else intermediate_size
|
||||||
|
|||||||
@@ -56,6 +56,9 @@ class Mask2FormerConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, `False`):
|
use_timm_backbone (`bool`, *optional*, `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
feature_size (`int`, *optional*, defaults to 256):
|
feature_size (`int`, *optional*, defaults to 256):
|
||||||
The features (channels) of the resulting feature maps.
|
The features (channels) of the resulting feature maps.
|
||||||
mask_feature_size (`int`, *optional*, defaults to 256):
|
mask_feature_size (`int`, *optional*, defaults to 256):
|
||||||
@@ -163,9 +166,10 @@ class Mask2FormerConfig(PretrainedConfig):
|
|||||||
use_auxiliary_loss: bool = True,
|
use_auxiliary_loss: bool = True,
|
||||||
feature_strides: List[int] = [4, 8, 16, 32],
|
feature_strides: List[int] = [4, 8, 16, 32],
|
||||||
output_auxiliary_logits: bool = None,
|
output_auxiliary_logits: bool = None,
|
||||||
backbone=None,
|
backbone: Optional[str] = None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone: bool = False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone: bool = False,
|
||||||
|
backbone_kwargs: Optional[Dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_pretrained_backbone:
|
if use_pretrained_backbone:
|
||||||
@@ -189,6 +193,9 @@ class Mask2FormerConfig(PretrainedConfig):
|
|||||||
out_features=["stage1", "stage2", "stage3", "stage4"],
|
out_features=["stage1", "stage2", "stage3", "stage4"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if isinstance(backbone_config, dict):
|
if isinstance(backbone_config, dict):
|
||||||
backbone_model_type = backbone_config.pop("model_type")
|
backbone_model_type = backbone_config.pop("model_type")
|
||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
@@ -233,6 +240,7 @@ class Mask2FormerConfig(PretrainedConfig):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, `False`):
|
use_timm_backbone (`bool`, *optional*, `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
decoder_config (`Dict`, *optional*):
|
decoder_config (`Dict`, *optional*):
|
||||||
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
|
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
|
||||||
will be used.
|
will be used.
|
||||||
@@ -126,6 +129,7 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
backbone: Optional[str] = None,
|
backbone: Optional[str] = None,
|
||||||
use_pretrained_backbone: bool = False,
|
use_pretrained_backbone: bool = False,
|
||||||
use_timm_backbone: bool = False,
|
use_timm_backbone: bool = False,
|
||||||
|
backbone_kwargs: Optional[Dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_pretrained_backbone:
|
if use_pretrained_backbone:
|
||||||
@@ -134,6 +138,9 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
if backbone_config is not None and backbone is not None:
|
if backbone_config is not None and backbone is not None:
|
||||||
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
|
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if backbone_config is None and backbone is None:
|
if backbone_config is None and backbone is None:
|
||||||
# fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
|
# fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
|
||||||
backbone_config = SwinConfig(
|
backbone_config = SwinConfig(
|
||||||
@@ -198,6 +205,7 @@ class MaskFormerConfig(PretrainedConfig):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -53,6 +53,9 @@ class OneFormerConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
ignore_value (`int`, *optional*, defaults to 255):
|
ignore_value (`int`, *optional*, defaults to 255):
|
||||||
Values to be ignored in GT label while calculating loss.
|
Values to be ignored in GT label while calculating loss.
|
||||||
num_queries (`int`, *optional*, defaults to 150):
|
num_queries (`int`, *optional*, defaults to 150):
|
||||||
@@ -156,6 +159,7 @@ class OneFormerConfig(PretrainedConfig):
|
|||||||
backbone: Optional[str] = None,
|
backbone: Optional[str] = None,
|
||||||
use_pretrained_backbone: bool = False,
|
use_pretrained_backbone: bool = False,
|
||||||
use_timm_backbone: bool = False,
|
use_timm_backbone: bool = False,
|
||||||
|
backbone_kwargs: Optional[Dict] = None,
|
||||||
ignore_value: int = 255,
|
ignore_value: int = 255,
|
||||||
num_queries: int = 150,
|
num_queries: int = 150,
|
||||||
no_object_weight: int = 0.1,
|
no_object_weight: int = 0.1,
|
||||||
@@ -223,10 +227,14 @@ class OneFormerConfig(PretrainedConfig):
|
|||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.ignore_value = ignore_value
|
self.ignore_value = ignore_value
|
||||||
self.num_queries = num_queries
|
self.num_queries = num_queries
|
||||||
self.no_object_weight = no_object_weight
|
self.no_object_weight = no_object_weight
|
||||||
|
|||||||
@@ -98,6 +98,9 @@ class TableTransformerConfig(PretrainedConfig):
|
|||||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||||
use_pretrained_backbone (`bool`, *optional*, `True`):
|
use_pretrained_backbone (`bool`, *optional*, `True`):
|
||||||
Whether to use pretrained weights for the backbone.
|
Whether to use pretrained weights for the backbone.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
dilation (`bool`, *optional*, defaults to `False`):
|
dilation (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
|
||||||
`use_timm_backbone` = `True`.
|
`use_timm_backbone` = `True`.
|
||||||
@@ -167,6 +170,7 @@ class TableTransformerConfig(PretrainedConfig):
|
|||||||
position_embedding_type="sine",
|
position_embedding_type="sine",
|
||||||
backbone="resnet50",
|
backbone="resnet50",
|
||||||
use_pretrained_backbone=True,
|
use_pretrained_backbone=True,
|
||||||
|
backbone_kwargs=None,
|
||||||
dilation=False,
|
dilation=False,
|
||||||
class_cost=1,
|
class_cost=1,
|
||||||
bbox_cost=5,
|
bbox_cost=5,
|
||||||
@@ -189,6 +193,9 @@ class TableTransformerConfig(PretrainedConfig):
|
|||||||
if backbone_config is not None and use_timm_backbone:
|
if backbone_config is not None and use_timm_backbone:
|
||||||
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if not use_timm_backbone:
|
if not use_timm_backbone:
|
||||||
if backbone_config is None:
|
if backbone_config is None:
|
||||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
||||||
@@ -224,6 +231,7 @@ class TableTransformerConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
# Hungarian matcher
|
# Hungarian matcher
|
||||||
self.class_cost = class_cost
|
self.class_cost = class_cost
|
||||||
|
|||||||
@@ -52,6 +52,9 @@ class TvpConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
distance_loss_weight (`float`, *optional*, defaults to 1.0):
|
distance_loss_weight (`float`, *optional*, defaults to 1.0):
|
||||||
The weight of distance loss.
|
The weight of distance loss.
|
||||||
duration_loss_weight (`float`, *optional*, defaults to 0.1):
|
duration_loss_weight (`float`, *optional*, defaults to 0.1):
|
||||||
@@ -107,6 +110,7 @@ class TvpConfig(PretrainedConfig):
|
|||||||
backbone=None,
|
backbone=None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone=False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone=False,
|
||||||
|
backbone_kwargs=None,
|
||||||
distance_loss_weight=1.0,
|
distance_loss_weight=1.0,
|
||||||
duration_loss_weight=0.1,
|
duration_loss_weight=0.1,
|
||||||
visual_prompter_type="framepad",
|
visual_prompter_type="framepad",
|
||||||
@@ -144,10 +148,14 @@ class TvpConfig(PretrainedConfig):
|
|||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.distance_loss_weight = distance_loss_weight
|
self.distance_loss_weight = distance_loss_weight
|
||||||
self.duration_loss_weight = duration_loss_weight
|
self.duration_loss_weight = duration_loss_weight
|
||||||
self.visual_prompter_type = visual_prompter_type
|
self.visual_prompter_type = visual_prompter_type
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ class UperNetConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, `False`):
|
use_timm_backbone (`bool`, *optional*, `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
hidden_size (`int`, *optional*, defaults to 512):
|
hidden_size (`int`, *optional*, defaults to 512):
|
||||||
The number of hidden units in the convolutional layers.
|
The number of hidden units in the convolutional layers.
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
@@ -87,6 +90,7 @@ class UperNetConfig(PretrainedConfig):
|
|||||||
backbone=None,
|
backbone=None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone=False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone=False,
|
||||||
|
backbone_kwargs=None,
|
||||||
hidden_size=512,
|
hidden_size=512,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
pool_scales=[1, 2, 3, 6],
|
pool_scales=[1, 2, 3, 6],
|
||||||
@@ -114,10 +118,14 @@ class UperNetConfig(PretrainedConfig):
|
|||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.pool_scales = pool_scales
|
self.pool_scales = pool_scales
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ class ViTHybridConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
hidden_size (`int`, *optional*, defaults to 768):
|
hidden_size (`int`, *optional*, defaults to 768):
|
||||||
Dimensionality of the encoder layers and the pooler layer.
|
Dimensionality of the encoder layers and the pooler layer.
|
||||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||||
@@ -104,6 +107,7 @@ class ViTHybridConfig(PretrainedConfig):
|
|||||||
backbone=None,
|
backbone=None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone=False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone=False,
|
||||||
|
backbone_kwargs=None,
|
||||||
hidden_size=768,
|
hidden_size=768,
|
||||||
num_hidden_layers=12,
|
num_hidden_layers=12,
|
||||||
num_attention_heads=12,
|
num_attention_heads=12,
|
||||||
@@ -137,6 +141,9 @@ class ViTHybridConfig(PretrainedConfig):
|
|||||||
"embedding_dynamic_padding": True,
|
"embedding_dynamic_padding": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
if isinstance(backbone_config, dict):
|
if isinstance(backbone_config, dict):
|
||||||
if "model_type" in backbone_config:
|
if "model_type" in backbone_config:
|
||||||
backbone_config_class = CONFIG_MAPPING[backbone_config["model_type"]]
|
backbone_config_class = CONFIG_MAPPING[backbone_config["model_type"]]
|
||||||
@@ -152,6 +159,7 @@ class ViTHybridConfig(PretrainedConfig):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ class VitMatteConfig(PretrainedConfig):
|
|||||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||||
library.
|
library.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
hidden_size (`int`, *optional*, defaults to 384):
|
hidden_size (`int`, *optional*, defaults to 384):
|
||||||
The number of input channels of the decoder.
|
The number of input channels of the decoder.
|
||||||
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||||
@@ -85,6 +88,7 @@ class VitMatteConfig(PretrainedConfig):
|
|||||||
backbone=None,
|
backbone=None,
|
||||||
use_pretrained_backbone=False,
|
use_pretrained_backbone=False,
|
||||||
use_timm_backbone=False,
|
use_timm_backbone=False,
|
||||||
|
backbone_kwargs=None,
|
||||||
hidden_size: int = 384,
|
hidden_size: int = 384,
|
||||||
batch_norm_eps: float = 1e-5,
|
batch_norm_eps: float = 1e-5,
|
||||||
initializer_range: float = 0.02,
|
initializer_range: float = 0.02,
|
||||||
@@ -108,10 +112,14 @@ class VitMatteConfig(PretrainedConfig):
|
|||||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
backbone_config = config_class.from_dict(backbone_config)
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
|
||||||
|
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
self.backbone_config = backbone_config
|
self.backbone_config = backbone_config
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.use_pretrained_backbone = use_pretrained_backbone
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
self.use_timm_backbone = use_timm_backbone
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
self.batch_norm_eps = batch_norm_eps
|
self.batch_norm_eps = batch_norm_eps
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
|||||||
@@ -304,6 +304,12 @@ def load_backbone(config):
|
|||||||
use_timm_backbone = getattr(config, "use_timm_backbone", None)
|
use_timm_backbone = getattr(config, "use_timm_backbone", None)
|
||||||
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
|
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
|
||||||
backbone_checkpoint = getattr(config, "backbone", None)
|
backbone_checkpoint = getattr(config, "backbone", None)
|
||||||
|
backbone_kwargs = getattr(config, "backbone_kwargs", None)
|
||||||
|
|
||||||
|
backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs
|
||||||
|
|
||||||
|
if backbone_kwargs and backbone_config is not None:
|
||||||
|
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||||
|
|
||||||
# If there is a backbone_config and a backbone checkpoint, and use_pretrained_backbone=False then the desired
|
# If there is a backbone_config and a backbone checkpoint, and use_pretrained_backbone=False then the desired
|
||||||
# behaviour is ill-defined: do you want to load from the checkpoint's config or the backbone_config?
|
# behaviour is ill-defined: do you want to load from the checkpoint's config or the backbone_config?
|
||||||
@@ -317,7 +323,7 @@ def load_backbone(config):
|
|||||||
and backbone_checkpoint is None
|
and backbone_checkpoint is None
|
||||||
and backbone_checkpoint is None
|
and backbone_checkpoint is None
|
||||||
):
|
):
|
||||||
return AutoBackbone.from_config(config=config)
|
return AutoBackbone.from_config(config=config, **backbone_kwargs)
|
||||||
|
|
||||||
# config from the parent model that has a backbone
|
# config from the parent model that has a backbone
|
||||||
if use_timm_backbone:
|
if use_timm_backbone:
|
||||||
@@ -326,16 +332,19 @@ def load_backbone(config):
|
|||||||
# Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
|
# Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
|
||||||
# to determine whether to load the pretrained weights.
|
# to determine whether to load the pretrained weights.
|
||||||
backbone = AutoBackbone.from_pretrained(
|
backbone = AutoBackbone.from_pretrained(
|
||||||
backbone_checkpoint, use_timm_backbone=use_timm_backbone, use_pretrained_backbone=use_pretrained_backbone
|
backbone_checkpoint,
|
||||||
|
use_timm_backbone=use_timm_backbone,
|
||||||
|
use_pretrained_backbone=use_pretrained_backbone,
|
||||||
|
**backbone_kwargs,
|
||||||
)
|
)
|
||||||
elif use_pretrained_backbone:
|
elif use_pretrained_backbone:
|
||||||
if backbone_checkpoint is None:
|
if backbone_checkpoint is None:
|
||||||
raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
|
raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
|
||||||
backbone = AutoBackbone.from_pretrained(backbone_checkpoint)
|
backbone = AutoBackbone.from_pretrained(backbone_checkpoint, **backbone_kwargs)
|
||||||
else:
|
else:
|
||||||
if backbone_config is None and backbone_checkpoint is None:
|
if backbone_config is None and backbone_checkpoint is None:
|
||||||
raise ValueError("Either config.backbone_config or config.backbone must be set")
|
raise ValueError("Either config.backbone_config or config.backbone must be set")
|
||||||
if backbone_config is None:
|
if backbone_config is None:
|
||||||
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint)
|
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs)
|
||||||
backbone = AutoBackbone.from_config(config=backbone_config)
|
backbone = AutoBackbone.from_config(config=backbone_config)
|
||||||
return backbone
|
return backbone
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import unittest
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import DetrConfig, MaskFormerConfig
|
from transformers import DetrConfig, MaskFormerConfig, ResNetBackbone, ResNetConfig, TimmBackbone
|
||||||
from transformers.testing_utils import require_torch, slow
|
from transformers.testing_utils import require_torch, slow
|
||||||
from transformers.utils.backbone_utils import (
|
from transformers.utils.backbone_utils import (
|
||||||
BackboneMixin,
|
BackboneMixin,
|
||||||
@@ -137,6 +137,65 @@ class BackboneUtilsTester(unittest.TestCase):
|
|||||||
self.assertEqual(backbone.out_features, ["a", "c"])
|
self.assertEqual(backbone.out_features, ["a", "c"])
|
||||||
self.assertEqual(backbone.out_indices, [-3, -1])
|
self.assertEqual(backbone.out_indices, [-3, -1])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_load_backbone_from_config(self):
|
||||||
|
"""
|
||||||
|
Test that load_backbone correctly loads a backbone from a backbone config.
|
||||||
|
"""
|
||||||
|
config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2)))
|
||||||
|
backbone = load_backbone(config)
|
||||||
|
self.assertEqual(backbone.out_features, ["stem", "stage2"])
|
||||||
|
self.assertEqual(backbone.out_indices, (0, 2))
|
||||||
|
self.assertIsInstance(backbone, ResNetBackbone)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_load_backbone_from_checkpoint(self):
|
||||||
|
"""
|
||||||
|
Test that load_backbone correctly loads a backbone from a checkpoint.
|
||||||
|
"""
|
||||||
|
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None)
|
||||||
|
backbone = load_backbone(config)
|
||||||
|
self.assertEqual(backbone.out_indices, [4])
|
||||||
|
self.assertEqual(backbone.out_features, ["stage4"])
|
||||||
|
self.assertIsInstance(backbone, ResNetBackbone)
|
||||||
|
|
||||||
|
config = MaskFormerConfig(
|
||||||
|
backbone="resnet18",
|
||||||
|
use_timm_backbone=True,
|
||||||
|
)
|
||||||
|
backbone = load_backbone(config)
|
||||||
|
# We can't know ahead of time the exact output features and indices, or the layer names before
|
||||||
|
# creating the timm model, so it defaults to the last layer (-1,) and has a different layer name
|
||||||
|
self.assertEqual(backbone.out_indices, (-1,))
|
||||||
|
self.assertEqual(backbone.out_features, ["layer4"])
|
||||||
|
self.assertIsInstance(backbone, TimmBackbone)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_load_backbone_backbone_kwargs(self):
|
||||||
|
"""
|
||||||
|
Test that load_backbone correctly configures the loaded backbone with the provided kwargs.
|
||||||
|
"""
|
||||||
|
config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)})
|
||||||
|
backbone = load_backbone(config)
|
||||||
|
self.assertEqual(backbone.out_indices, (0, 1))
|
||||||
|
self.assertIsInstance(backbone, TimmBackbone)
|
||||||
|
|
||||||
|
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)})
|
||||||
|
backbone = load_backbone(config)
|
||||||
|
self.assertEqual(backbone.out_indices, (0, 2))
|
||||||
|
self.assertIsInstance(backbone, ResNetBackbone)
|
||||||
|
|
||||||
|
# Check can't be passed with a backone config
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
config = MaskFormerConfig(
|
||||||
|
backbone="microsoft/resnet-18",
|
||||||
|
backbone_config=ResNetConfig(out_indices=(0, 2)),
|
||||||
|
backbone_kwargs={"out_indices": (0, 1)},
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_load_backbone_in_new_model(self):
|
def test_load_backbone_in_new_model(self):
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
|
|||||||
"backbone",
|
"backbone",
|
||||||
"backbone_config",
|
"backbone_config",
|
||||||
"use_timm_backbone",
|
"use_timm_backbone",
|
||||||
|
"backbone_kwargs",
|
||||||
]
|
]
|
||||||
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
|
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user