Enable instantiating model with pretrained backbone weights (#28214)

* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Remove doc updates until changes made in modeling code

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Small test updates

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
amyeroberts
2024-01-23 11:01:50 +00:00
committed by GitHub
parent 008a6a2208
commit 27c79a0fb4
31 changed files with 362 additions and 37 deletions

View File

@@ -134,6 +134,8 @@ class ConditionalDetrModelTester:
num_labels=self.num_labels,
use_timm_backbone=False,
backbone_config=resnet_config,
backbone=None,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):

View File

@@ -149,7 +149,9 @@ class DeformableDetrModelTester:
encoder_n_points=self.encoder_n_points,
decoder_n_points=self.decoder_n_points,
use_timm_backbone=False,
backbone=None,
backbone_config=resnet_config,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
@@ -518,6 +520,8 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
config.backbone_config = None
for model_class in self.all_model_classes:
model = model_class(config)

View File

@@ -157,6 +157,7 @@ class DetaModelTester:
assign_first_stage=assign_first_stage,
assign_second_stage=assign_second_stage,
backbone_config=resnet_config,
backbone=None,
)
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):

View File

@@ -130,6 +130,8 @@ class DetrModelTester:
num_labels=self.num_labels,
use_timm_backbone=False,
backbone_config=resnet_config,
backbone=None,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
@@ -622,7 +624,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
torch_device
)
expected_number_of_segments = 5
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096}
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994097}
number_of_unique_segments = len(torch.unique(results["segmentation"]))
self.assertTrue(

View File

@@ -95,6 +95,7 @@ class DPTModelTester:
def get_config(self):
return DPTConfig(
backbone_config=self.get_backbone_config(),
backbone=None,
neck_hidden_sizes=self.neck_hidden_sizes,
fusion_hidden_size=self.fusion_hidden_size,
)

View File

@@ -130,6 +130,7 @@ class DPTModelTester:
initializer_range=self.initializer_range,
is_hybrid=self.is_hybrid,
backbone_config=backbone_config,
backbone=None,
backbone_featmap_shape=self.backbone_featmap_shape,
neck_hidden_sizes=self.neck_hidden_sizes,
)

View File

@@ -114,6 +114,7 @@ class Mask2FormerModelTester:
config.backbone_config.hidden_size = 16
config.backbone_config.num_channels = self.num_channels
config.backbone_config.num_heads = [1, 1, 2, 2]
config.backbone = None
config.hidden_dim = self.hidden_dim
config.mask_feature_size = self.hidden_dim

View File

@@ -102,6 +102,7 @@ class MaskFormerModelTester:
hidden_size=32,
num_heads=[1, 1, 2, 2],
),
backbone=None,
decoder_config=DetrConfig(
decoder_ffn_dim=64,
decoder_layers=self.num_hidden_layers,

View File

@@ -133,6 +133,7 @@ class OneFormerModelTester:
config.backbone_config.hidden_size = 16
config.backbone_config.num_channels = self.num_channels
config.backbone_config.num_heads = [1, 1, 2, 2]
config.backbone = None
config.hidden_dim = self.hidden_dim
config.mask_dim = self.hidden_dim

View File

@@ -131,6 +131,8 @@ class TableTransformerModelTester:
num_labels=self.num_labels,
use_timm_backbone=False,
backbone_config=resnet_config,
backbone=None,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):

View File

@@ -124,6 +124,7 @@ class TVPModelTester:
)
return TvpConfig(
backbone_config=resnet_config,
backbone=None,
alpha=self.alpha,
beta=self.beta,
visual_prompter_type=self.visual_prompter_type,

View File

@@ -105,6 +105,7 @@ class UperNetModelTester:
def get_config(self):
return UperNetConfig(
backbone_config=self.get_backbone_config(),
backbone=None,
hidden_size=64,
pool_scales=[1, 2, 3, 6],
use_auxiliary_head=True,

View File

@@ -122,6 +122,7 @@ class ViTHybridModelTester:
initializer_range=self.initializer_range,
backbone_featmap_shape=self.backbone_featmap_shape,
backbone_config=backbone_config,
backbone=None,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@@ -111,6 +111,7 @@ class VitMatteModelTester:
def get_config(self):
return VitMatteConfig(
backbone_config=self.get_backbone_config(),
backbone=None,
hidden_size=self.hidden_size,
fusion_hidden_sizes=self.fusion_hidden_sizes,
)