[Backbone] Use load_backbone instead of AutoBackbone.from_config (#28661)

* Enable instantiating model with pretrained backbone weights

* Remove doc updates until changes made in modeling code

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add missing imports and arguments

* Update docstrings

* Make sure test is properly configured

* Include recent DPT updates
This commit is contained in:
amyeroberts
2024-01-30 16:54:09 +00:00
committed by GitHub
parent c24c52454a
commit 2fa1c808ae
24 changed files with 89 additions and 44 deletions

View File

@@ -37,7 +37,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_conditional_detr import ConditionalDetrConfig from .configuration_conditional_detr import ConditionalDetrConfig
@@ -363,7 +363,7 @@ class ConditionalDetrConvEncoder(nn.Module):
**kwargs, **kwargs,
) )
else: else:
backbone = AutoBackbone.from_config(config.backbone_config) backbone = load_backbone(config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():

View File

@@ -44,7 +44,7 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging from ...utils import is_ninja_available, logging
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels from .load_custom import load_cuda_kernels
@@ -409,7 +409,7 @@ class DeformableDetrConvEncoder(nn.Module):
**kwargs, **kwargs,
) )
else: else:
backbone = AutoBackbone.from_config(config.backbone_config) backbone = load_backbone(config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():

View File

@@ -46,6 +46,9 @@ class DetaConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
num_queries (`int`, *optional*, defaults to 900): num_queries (`int`, *optional*, defaults to 900):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead. detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
@@ -146,6 +149,7 @@ class DetaConfig(PretrainedConfig):
backbone_config=None, backbone_config=None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
num_queries=900, num_queries=900,
max_position_embeddings=2048, max_position_embeddings=2048,
encoder_layers=6, encoder_layers=6,
@@ -203,6 +207,7 @@ class DetaConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.num_queries = num_queries self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.d_model = d_model self.d_model = d_model

View File

@@ -39,7 +39,7 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_deta import DetaConfig from .configuration_deta import DetaConfig
@@ -338,7 +338,7 @@ class DetaBackboneWithPositionalEncodings(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
backbone = AutoBackbone.from_config(config.backbone_config) backbone = load_backbone(config)
with torch.no_grad(): with torch.no_grad():
replace_batch_norm(backbone) replace_batch_norm(backbone)
self.model = backbone self.model = backbone

View File

@@ -37,7 +37,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_detr import DetrConfig from .configuration_detr import DetrConfig
@@ -356,7 +356,7 @@ class DetrConvEncoder(nn.Module):
**kwargs, **kwargs,
) )
else: else:
backbone = AutoBackbone.from_config(config.backbone_config) backbone = load_backbone(config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():

View File

@@ -117,6 +117,9 @@ class DPTConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
Example: Example:
@@ -169,6 +172,7 @@ class DPTConfig(PretrainedConfig):
backbone_config=None, backbone_config=None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -179,9 +183,6 @@ class DPTConfig(PretrainedConfig):
if use_pretrained_backbone: if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.") raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
use_autobackbone = False use_autobackbone = False
if self.is_hybrid: if self.is_hybrid:
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
@@ -193,17 +194,17 @@ class DPTConfig(PretrainedConfig):
"out_features": ["stage1", "stage2", "stage3"], "out_features": ["stage1", "stage2", "stage3"],
"embedding_dynamic_padding": True, "embedding_dynamic_padding": True,
} }
self.backbone_config = BitConfig(**backbone_config) backbone_config = BitConfig(**backbone_config)
elif isinstance(backbone_config, dict): elif isinstance(backbone_config, dict):
logger.info("Initializing the config with a `BiT` backbone.") logger.info("Initializing the config with a `BiT` backbone.")
self.backbone_config = BitConfig(**backbone_config) backbone_config = BitConfig(**backbone_config)
elif isinstance(backbone_config, PretrainedConfig): elif isinstance(backbone_config, PretrainedConfig):
self.backbone_config = backbone_config backbone_config = backbone_config
else: else:
raise ValueError( raise ValueError(
f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}." f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
) )
self.backbone_config = backbone_config
self.backbone_featmap_shape = backbone_featmap_shape self.backbone_featmap_shape = backbone_featmap_shape
self.neck_ignore_stages = neck_ignore_stages self.neck_ignore_stages = neck_ignore_stages
@@ -221,14 +222,17 @@ class DPTConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone_featmap_shape = None self.backbone_featmap_shape = None
self.neck_ignore_stages = [] self.neck_ignore_stages = []
else: else:
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone_featmap_shape = None self.backbone_featmap_shape = None
self.neck_ignore_stages = [] self.neck_ignore_stages = []
if use_autobackbone and backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
self.num_attention_heads = None if use_autobackbone else num_attention_heads self.num_attention_heads = None if use_autobackbone else num_attention_heads
self.intermediate_size = None if use_autobackbone else intermediate_size self.intermediate_size = None if use_autobackbone else intermediate_size

View File

@@ -41,7 +41,7 @@ from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging from ...utils import ModelOutput, logging
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_dpt import DPTConfig from .configuration_dpt import DPTConfig
@@ -131,12 +131,10 @@ class DPTViTHybridEmbeddings(nn.Module):
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.backbone = AutoBackbone.from_config(config.backbone_config) self.backbone = load_backbone(config)
feature_dim = self.backbone.channels[-1] feature_dim = self.backbone.channels[-1]
if len(config.backbone_config.out_features) != 3: if len(self.backbone.channels) != 3:
raise ValueError( raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
f"Expected backbone to have 3 output features, got {len(config.backbone_config.out_features)}"
)
self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
if feature_size is None: if feature_size is None:
@@ -1082,7 +1080,7 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
self.backbone = None self.backbone = None
if config.backbone_config is not None and config.is_hybrid is False: if config.backbone_config is not None and config.is_hybrid is False:
self.backbone = AutoBackbone.from_config(config.backbone_config) self.backbone = load_backbone(config)
else: else:
self.dpt = DPTModel(config, add_pooling_layer=False) self.dpt = DPTModel(config, add_pooling_layer=False)

View File

@@ -53,6 +53,9 @@ class Mask2FormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
feature_size (`int`, *optional*, defaults to 256): feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the resulting feature maps. The features (channels) of the resulting feature maps.
mask_feature_size (`int`, *optional*, defaults to 256): mask_feature_size (`int`, *optional*, defaults to 256):
@@ -162,6 +165,7 @@ class Mask2FormerConfig(PretrainedConfig):
output_auxiliary_logits: bool = None, output_auxiliary_logits: bool = None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
**kwargs, **kwargs,
): ):
if use_pretrained_backbone: if use_pretrained_backbone:
@@ -228,6 +232,7 @@ class Mask2FormerConfig(PretrainedConfig):
self.num_hidden_layers = decoder_layers self.num_hidden_layers = decoder_layers
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@@ -23,7 +23,6 @@ import numpy as np
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from ... import AutoBackbone
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ( from ...file_utils import (
ModelOutput, ModelOutput,
@@ -36,6 +35,7 @@ from ...file_utils import (
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import load_backbone
from .configuration_mask2former import Mask2FormerConfig from .configuration_mask2former import Mask2FormerConfig
@@ -1376,7 +1376,7 @@ class Mask2FormerPixelLevelModule(nn.Module):
""" """
super().__init__() super().__init__()
self.encoder = AutoBackbone.from_config(config.backbone_config) self.encoder = load_backbone(config)
self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:

View File

@@ -63,6 +63,9 @@ class MaskFormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
decoder_config (`Dict`, *optional*): decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50` The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
will be used. will be used.
@@ -122,6 +125,7 @@ class MaskFormerConfig(PretrainedConfig):
output_auxiliary_logits: Optional[bool] = None, output_auxiliary_logits: Optional[bool] = None,
backbone: Optional[str] = None, backbone: Optional[str] = None,
use_pretrained_backbone: bool = False, use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
**kwargs, **kwargs,
): ):
if use_pretrained_backbone: if use_pretrained_backbone:
@@ -193,6 +197,7 @@ class MaskFormerConfig(PretrainedConfig):
self.num_hidden_layers = self.decoder_config.num_hidden_layers self.num_hidden_layers = self.decoder_config.num_hidden_layers
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
super().__init__(**kwargs) super().__init__(**kwargs)
@classmethod @classmethod

View File

@@ -23,7 +23,6 @@ import numpy as np
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from ... import AutoBackbone
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_outputs import BaseModelOutputWithCrossAttentions
@@ -37,6 +36,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ...utils.backbone_utils import load_backbone
from ..detr import DetrConfig from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -1428,14 +1428,13 @@ class MaskFormerPixelLevelModule(nn.Module):
The configuration used to instantiate this model. The configuration used to instantiate this model.
""" """
super().__init__() super().__init__()
if hasattr(config, "backbone_config") and config.backbone_config.model_type == "swin":
# TODO: add method to load pretrained weights of backbone
backbone_config = config.backbone_config
if backbone_config.model_type == "swin":
# for backwards compatibility # for backwards compatibility
backbone_config = config.backbone_config
backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
self.encoder = AutoBackbone.from_config(backbone_config) config.backbone_config = backbone_config
self.encoder = load_backbone(config)
feature_channels = self.encoder.channels feature_channels = self.encoder.channels
self.decoder = MaskFormerPixelDecoder( self.decoder = MaskFormerPixelDecoder(

View File

@@ -50,6 +50,9 @@ class OneFormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
ignore_value (`int`, *optional*, defaults to 255): ignore_value (`int`, *optional*, defaults to 255):
Values to be ignored in GT label while calculating loss. Values to be ignored in GT label while calculating loss.
num_queries (`int`, *optional*, defaults to 150): num_queries (`int`, *optional*, defaults to 150):
@@ -152,6 +155,7 @@ class OneFormerConfig(PretrainedConfig):
backbone_config: Optional[Dict] = None, backbone_config: Optional[Dict] = None,
backbone: Optional[str] = None, backbone: Optional[str] = None,
use_pretrained_backbone: bool = False, use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
ignore_value: int = 255, ignore_value: int = 255,
num_queries: int = 150, num_queries: int = 150,
no_object_weight: int = 0.1, no_object_weight: int = 0.1,
@@ -222,6 +226,7 @@ class OneFormerConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.ignore_value = ignore_value self.ignore_value = ignore_value
self.num_queries = num_queries self.num_queries = num_queries
self.no_object_weight = no_object_weight self.no_object_weight = no_object_weight

View File

@@ -24,7 +24,6 @@ import torch
from torch import Tensor, nn from torch import Tensor, nn
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from ... import AutoBackbone
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
@@ -37,6 +36,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ...utils.backbone_utils import load_backbone
from .configuration_oneformer import OneFormerConfig from .configuration_oneformer import OneFormerConfig
@@ -1478,8 +1478,7 @@ class OneFormerPixelLevelModule(nn.Module):
The configuration used to instantiate this model. The configuration used to instantiate this model.
""" """
super().__init__() super().__init__()
backbone_config = config.backbone_config self.encoder = load_backbone(config)
self.encoder = AutoBackbone.from_config(backbone_config)
self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels) self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput: def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput:

View File

@@ -37,7 +37,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_table_transformer import TableTransformerConfig from .configuration_table_transformer import TableTransformerConfig
@@ -290,7 +290,7 @@ class TableTransformerConvEncoder(nn.Module):
**kwargs, **kwargs,
) )
else: else:
backbone = AutoBackbone.from_config(config.backbone_config) backbone = load_backbone(config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():

View File

@@ -49,6 +49,9 @@ class TvpConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
distance_loss_weight (`float`, *optional*, defaults to 1.0): distance_loss_weight (`float`, *optional*, defaults to 1.0):
The weight of distance loss. The weight of distance loss.
duration_loss_weight (`float`, *optional*, defaults to 0.1): duration_loss_weight (`float`, *optional*, defaults to 0.1):
@@ -103,6 +106,7 @@ class TvpConfig(PretrainedConfig):
backbone_config=None, backbone_config=None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
distance_loss_weight=1.0, distance_loss_weight=1.0,
duration_loss_weight=0.1, duration_loss_weight=0.1,
visual_prompter_type="framepad", visual_prompter_type="framepad",
@@ -143,6 +147,7 @@ class TvpConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.distance_loss_weight = distance_loss_weight self.distance_loss_weight = distance_loss_weight
self.duration_loss_weight = duration_loss_weight self.duration_loss_weight = duration_loss_weight
self.visual_prompter_type = visual_prompter_type self.visual_prompter_type = visual_prompter_type

View File

@@ -28,7 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Mod
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import prune_linear_layer from ...pytorch_utils import prune_linear_layer
from ...utils import logging from ...utils import logging
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_tvp import TvpConfig from .configuration_tvp import TvpConfig
@@ -148,7 +148,7 @@ class TvpLoss(nn.Module):
class TvpVisionModel(nn.Module): class TvpVisionModel(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.backbone = AutoBackbone.from_config(config.backbone_config) self.backbone = load_backbone(config)
self.grid_encoder_conv = nn.Conv2d( self.grid_encoder_conv = nn.Conv2d(
config.backbone_config.hidden_sizes[-1], config.backbone_config.hidden_sizes[-1],
config.hidden_size, config.hidden_size,

View File

@@ -42,6 +42,9 @@ class UperNetConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
hidden_size (`int`, *optional*, defaults to 512): hidden_size (`int`, *optional*, defaults to 512):
The number of hidden units in the convolutional layers. The number of hidden units in the convolutional layers.
initializer_range (`float`, *optional*, defaults to 0.02): initializer_range (`float`, *optional*, defaults to 0.02):
@@ -83,6 +86,7 @@ class UperNetConfig(PretrainedConfig):
backbone_config=None, backbone_config=None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
hidden_size=512, hidden_size=512,
initializer_range=0.02, initializer_range=0.02,
pool_scales=[1, 2, 3, 6], pool_scales=[1, 2, 3, 6],
@@ -113,6 +117,7 @@ class UperNetConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.pool_scales = pool_scales self.pool_scales = pool_scales

View File

@@ -20,10 +20,10 @@ import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ... import AutoBackbone
from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_outputs import SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...utils.backbone_utils import load_backbone
from .configuration_upernet import UperNetConfig from .configuration_upernet import UperNetConfig
@@ -348,7 +348,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.backbone = AutoBackbone.from_config(config.backbone_config) self.backbone = load_backbone(config)
# Semantic segmentation head(s) # Semantic segmentation head(s)
self.decode_head = UperNetHead(config, in_channels=self.backbone.channels) self.decode_head = UperNetHead(config, in_channels=self.backbone.channels)

View File

@@ -48,6 +48,9 @@ class ViTHybridConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
hidden_size (`int`, *optional*, defaults to 768): hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer. Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12): num_hidden_layers (`int`, *optional*, defaults to 12):
@@ -100,6 +103,7 @@ class ViTHybridConfig(PretrainedConfig):
backbone_config=None, backbone_config=None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
hidden_size=768, hidden_size=768,
num_hidden_layers=12, num_hidden_layers=12,
num_attention_heads=12, num_attention_heads=12,
@@ -147,6 +151,7 @@ class ViTHybridConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads

View File

@@ -29,7 +29,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Ima
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ..auto import AutoBackbone from ...utils.backbone_utils import load_backbone
from .configuration_vit_hybrid import ViTHybridConfig from .configuration_vit_hybrid import ViTHybridConfig
@@ -150,7 +150,7 @@ class ViTHybridPatchEmbeddings(nn.Module):
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
self.backbone = AutoBackbone.from_config(config.backbone_config) self.backbone = load_backbone(config)
if self.backbone.config.model_type != "bit": if self.backbone.config.model_type != "bit":
raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.") raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.")
feature_dim = self.backbone.channels[-1] feature_dim = self.backbone.channels[-1]

View File

@@ -48,6 +48,9 @@ class VitMatteConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
hidden_size (`int`, *optional*, defaults to 384): hidden_size (`int`, *optional*, defaults to 384):
The number of input channels of the decoder. The number of input channels of the decoder.
batch_norm_eps (`float`, *optional*, defaults to 1e-05): batch_norm_eps (`float`, *optional*, defaults to 1e-05):
@@ -81,6 +84,7 @@ class VitMatteConfig(PretrainedConfig):
backbone_config: PretrainedConfig = None, backbone_config: PretrainedConfig = None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
hidden_size: int = 384, hidden_size: int = 384,
batch_norm_eps: float = 1e-5, batch_norm_eps: float = 1e-5,
initializer_range: float = 0.02, initializer_range: float = 0.02,
@@ -107,6 +111,7 @@ class VitMatteConfig(PretrainedConfig):
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.batch_norm_eps = batch_norm_eps self.batch_norm_eps = batch_norm_eps
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.initializer_range = initializer_range self.initializer_range = initializer_range

View File

@@ -20,7 +20,6 @@ from typing import Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from ... import AutoBackbone
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
@@ -28,6 +27,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.backbone_utils import load_backbone
from .configuration_vitmatte import VitMatteConfig from .configuration_vitmatte import VitMatteConfig
@@ -259,7 +259,7 @@ class VitMatteForImageMatting(VitMattePreTrainedModel):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.backbone = AutoBackbone.from_config(config.backbone_config) self.backbone = load_backbone(config)
self.decoder = VitMatteDetailCaptureModule(config) self.decoder = VitMatteDetailCaptureModule(config)
# Initialize weights and apply final processing # Initialize weights and apply final processing

View File

@@ -443,6 +443,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
# let's pick a random timm backbone # let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075" config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)

View File

@@ -219,7 +219,11 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"out_features", "out_features",
"out_indices", "out_indices",
"sampling_rate", "sampling_rate",
# backbone related arguments passed to load_backbone
"use_pretrained_backbone", "use_pretrained_backbone",
"backbone",
"backbone_config",
"use_timm_backbone",
] ]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]