Backbone kwargs in config (#28784)

* Enable instantiating model with pretrained backbone weights

* Clarify pretrained import

* Use load_backbone instead

* Add backbone_kwargs to config

* Pass kwargs to constructors

* Fix up

* Input verification

* Add tests

* Tidy up

* Update tests/utils/test_backbone_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
amyeroberts
2024-02-14 20:46:44 +00:00
committed by GitHub
parent 725f4ad1cc
commit 0199a484eb
16 changed files with 181 additions and 8 deletions

View File

@@ -16,7 +16,7 @@ import unittest
import pytest
from transformers import DetrConfig, MaskFormerConfig
from transformers import DetrConfig, MaskFormerConfig, ResNetBackbone, ResNetConfig, TimmBackbone
from transformers.testing_utils import require_torch, slow
from transformers.utils.backbone_utils import (
BackboneMixin,
@@ -137,6 +137,65 @@ class BackboneUtilsTester(unittest.TestCase):
self.assertEqual(backbone.out_features, ["a", "c"])
self.assertEqual(backbone.out_indices, [-3, -1])
@slow
@require_torch
def test_load_backbone_from_config(self):
"""
Test that load_backbone correctly loads a backbone from a backbone config.
"""
config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2)))
backbone = load_backbone(config)
self.assertEqual(backbone.out_features, ["stem", "stage2"])
self.assertEqual(backbone.out_indices, (0, 2))
self.assertIsInstance(backbone, ResNetBackbone)
@slow
@require_torch
def test_load_backbone_from_checkpoint(self):
"""
Test that load_backbone correctly loads a backbone from a checkpoint.
"""
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None)
backbone = load_backbone(config)
self.assertEqual(backbone.out_indices, [4])
self.assertEqual(backbone.out_features, ["stage4"])
self.assertIsInstance(backbone, ResNetBackbone)
config = MaskFormerConfig(
backbone="resnet18",
use_timm_backbone=True,
)
backbone = load_backbone(config)
# We can't know ahead of time the exact output features and indices, or the layer names before
# creating the timm model, so it defaults to the last layer (-1,) and has a different layer name
self.assertEqual(backbone.out_indices, (-1,))
self.assertEqual(backbone.out_features, ["layer4"])
self.assertIsInstance(backbone, TimmBackbone)
@slow
@require_torch
def test_load_backbone_backbone_kwargs(self):
"""
Test that load_backbone correctly configures the loaded backbone with the provided kwargs.
"""
config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)})
backbone = load_backbone(config)
self.assertEqual(backbone.out_indices, (0, 1))
self.assertIsInstance(backbone, TimmBackbone)
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)})
backbone = load_backbone(config)
self.assertEqual(backbone.out_indices, (0, 2))
self.assertIsInstance(backbone, ResNetBackbone)
# Check can't be passed with a backone config
with pytest.raises(ValueError):
config = MaskFormerConfig(
backbone="microsoft/resnet-18",
backbone_config=ResNetConfig(out_indices=(0, 2)),
backbone_kwargs={"out_indices": (0, 1)},
)
@slow
@require_torch
def test_load_backbone_in_new_model(self):