[DETR] Remove timm hardcoded logic in modeling files (#29038)
* Enable instantiating model with pretrained backbone weights * Clarify pretrained import * Use load_backbone instead * Add backbone_kwargs to config * Fix up * Add tests * Tidy up * Enable instantiating model with pretrained backbone weights * Update tests so backbone checkpoint isn't passed in * Clarify pretrained import * Update configs - docs and validation check * Update src/transformers/utils/backbone_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Clarify exception message * Update config init in tests * Add test for when use_timm_backbone=True * Use load_backbone instead * Add use_timm_backbone to the model configs * Add backbone_kwargs to config * Pass kwargs to constructors * Draft * Fix tests * Add back timm - weight naming * More tidying up * Whoops * Tidy up * Handle when kwargs are none * Update tests * Revert test changes * Deformable detr test - don't use default * Don't mutate; correct model attributes * Add some clarifying comments * nit - grammar is hard --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -444,7 +444,9 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
@@ -460,6 +462,14 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
self.model_tester.num_labels,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
elif model_class.__name__ == "ConditionalDetrForSegmentation":
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
|
||||
@@ -521,8 +521,9 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
@@ -538,6 +539,14 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
self.model_tester.num_labels,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
|
||||
elif model_class.__name__ == "ConditionalDetrForSegmentation":
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.deformable_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
|
||||
@@ -444,6 +444,9 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
@@ -459,6 +462,14 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
self.model_tester.num_labels + 1,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
elif model_class.__name__ == "DetrForSegmentation":
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
|
||||
@@ -456,6 +456,9 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
@@ -471,6 +474,11 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
||||
self.model_tester.num_labels + 1,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user