Enable HF pretrained backbones (#31145)

* Enable load HF or tim backbone checkpoints

* Fix up

* Fix test - pass in proper out_indices

* Update docs

* Fix tvp tests

* Fix doc examples

* Fix doc examples

* Try to resolve DPT backbone param init

* Don't conditionally set to None

* Add condition based on whether backbone is defined

* Address review comments
This commit is contained in:
amyeroberts
2024-06-06 22:02:38 +01:00
committed by GitHub
parent a3d351c00f
commit bdf36dcd48
27 changed files with 546 additions and 69 deletions

View File

@@ -19,7 +19,14 @@ import unittest
from huggingface_hub import hf_hub_download
from transformers import ConvNextConfig, UperNetConfig
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.testing_utils import (
require_timm,
require_torch,
require_torch_multi_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -240,6 +247,33 @@ class UperNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@require_timm
def test_backbone_selection(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
config.backbone_config = None
config.backbone_kwargs = {"out_indices": [1, 2, 3]}
config.use_pretrained_backbone = True
# Load a timm backbone
# We can't load transformer checkpoint with timm backbone, as we can't specify features_only and out_indices
config.backbone = "resnet18"
config.use_timm_backbone = True
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device).eval()
if model.__class__.__name__ == "UperNetForUniversalSegmentation":
self.assertEqual(model.backbone.out_indices, [1, 2, 3])
# Load a HF backbone
config.backbone = "microsoft/resnet-18"
config.use_timm_backbone = False
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device).eval()
if model.__class__.__name__ == "UperNetForUniversalSegmentation":
self.assertEqual(model.backbone.out_indices, [1, 2, 3])
@unittest.skip(reason="UperNet does not have tied weights")
def test_tied_model_weights_key_ignore(self):
pass