[DETR] Remove timm hardcoded logic in modeling files (#29038)
* Enable instantiating model with pretrained backbone weights * Clarify pretrained import * Use load_backbone instead * Add backbone_kwargs to config * Fix up * Add tests * Tidy up * Enable instantiating model with pretrained backbone weights * Update tests so backbone checkpoint isn't passed in * Clarify pretrained import * Update configs - docs and validation check * Update src/transformers/utils/backbone_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Clarify exception message * Update config init in tests * Add test for when use_timm_backbone=True * Use load_backbone instead * Add use_timm_backbone to the model configs * Add backbone_kwargs to config * Pass kwargs to constructors * Draft * Fix tests * Add back timm - weight naming * More tidying up * Whoops * Tidy up * Handle when kwargs are none * Update tests * Revert test changes * Deformable detr test - don't use default * Don't mutate; correct model attributes * Add some clarifying comments * nit - grammar is hard --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -444,6 +444,9 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
@@ -459,6 +462,14 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
self.model_tester.num_labels + 1,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
elif model_class.__name__ == "DetrForSegmentation":
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user