Enable HF pretrained backbones (#31145)

* Enable load HF or tim backbone checkpoints

* Fix up

* Fix test - pass in proper out_indices

* Update docs

* Fix tvp tests

* Fix doc examples

* Fix doc examples

* Try to resolve DPT backbone param init

* Don't conditionally set to None

* Add condition based on whether backbone is defined

* Address review comments
This commit is contained in:
amyeroberts
2024-06-06 22:02:38 +01:00
committed by GitHub
parent a3d351c00f
commit bdf36dcd48
27 changed files with 546 additions and 69 deletions

View File

@@ -207,6 +207,35 @@ class DepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
model = DepthAnythingForDepthEstimation.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_backbone_selection(self):
def _validate_backbone_init():
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
# Confirm out_indices propogated to backbone
self.assertEqual(len(model.backbone.out_indices), 2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Load a timm backbone
config.backbone = "resnet18"
config.use_pretrained_backbone = True
config.use_timm_backbone = True
config.backbone_config = None
# For transformer backbones we can't set the out_indices or just return the features
config.backbone_kwargs = {"out_indices": (-2, -1)}
_validate_backbone_init()
# Load a HF backbone
config.backbone = "facebook/dinov2-small"
config.use_pretrained_backbone = True
config.use_timm_backbone = False
config.backbone_config = None
config.backbone_kwargs = {"out_indices": [-2, -1]}
_validate_backbone_init()
# We will verify our results on an image of cute cats
def prepare_img():