Enable instantiating model with pretrained backbone weights (#28214)

* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Remove doc updates until changes made in modeling code

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Small test updates

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
amyeroberts
2024-01-23 11:01:50 +00:00
committed by GitHub
parent 008a6a2208
commit 27c79a0fb4
31 changed files with 362 additions and 37 deletions

View File

@@ -149,7 +149,9 @@ class DeformableDetrModelTester:
encoder_n_points=self.encoder_n_points,
decoder_n_points=self.decoder_n_points,
use_timm_backbone=False,
backbone=None,
backbone_config=resnet_config,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
@@ -518,6 +520,8 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
config.backbone_config = None
for model_class in self.all_model_classes:
model = model_class(config)