Backbone kwargs in config (#28784)
* Enable instantiating model with pretrained backbone weights * Clarify pretrained import * Use load_backbone instead * Add backbone_kwargs to config * Pass kwargs to constructors * Fix up * Input verification * Add tests * Tidy up * Update tests/utils/test_backbone_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import DetrConfig, MaskFormerConfig
|
||||
from transformers import DetrConfig, MaskFormerConfig, ResNetBackbone, ResNetConfig, TimmBackbone
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
from transformers.utils.backbone_utils import (
|
||||
BackboneMixin,
|
||||
@@ -137,6 +137,65 @@ class BackboneUtilsTester(unittest.TestCase):
|
||||
self.assertEqual(backbone.out_features, ["a", "c"])
|
||||
self.assertEqual(backbone.out_indices, [-3, -1])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_backbone_from_config(self):
|
||||
"""
|
||||
Test that load_backbone correctly loads a backbone from a backbone config.
|
||||
"""
|
||||
config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2)))
|
||||
backbone = load_backbone(config)
|
||||
self.assertEqual(backbone.out_features, ["stem", "stage2"])
|
||||
self.assertEqual(backbone.out_indices, (0, 2))
|
||||
self.assertIsInstance(backbone, ResNetBackbone)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_backbone_from_checkpoint(self):
|
||||
"""
|
||||
Test that load_backbone correctly loads a backbone from a checkpoint.
|
||||
"""
|
||||
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None)
|
||||
backbone = load_backbone(config)
|
||||
self.assertEqual(backbone.out_indices, [4])
|
||||
self.assertEqual(backbone.out_features, ["stage4"])
|
||||
self.assertIsInstance(backbone, ResNetBackbone)
|
||||
|
||||
config = MaskFormerConfig(
|
||||
backbone="resnet18",
|
||||
use_timm_backbone=True,
|
||||
)
|
||||
backbone = load_backbone(config)
|
||||
# We can't know ahead of time the exact output features and indices, or the layer names before
|
||||
# creating the timm model, so it defaults to the last layer (-1,) and has a different layer name
|
||||
self.assertEqual(backbone.out_indices, (-1,))
|
||||
self.assertEqual(backbone.out_features, ["layer4"])
|
||||
self.assertIsInstance(backbone, TimmBackbone)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_backbone_backbone_kwargs(self):
|
||||
"""
|
||||
Test that load_backbone correctly configures the loaded backbone with the provided kwargs.
|
||||
"""
|
||||
config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)})
|
||||
backbone = load_backbone(config)
|
||||
self.assertEqual(backbone.out_indices, (0, 1))
|
||||
self.assertIsInstance(backbone, TimmBackbone)
|
||||
|
||||
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)})
|
||||
backbone = load_backbone(config)
|
||||
self.assertEqual(backbone.out_indices, (0, 2))
|
||||
self.assertIsInstance(backbone, ResNetBackbone)
|
||||
|
||||
# Check can't be passed with a backone config
|
||||
with pytest.raises(ValueError):
|
||||
config = MaskFormerConfig(
|
||||
backbone="microsoft/resnet-18",
|
||||
backbone_config=ResNetConfig(out_indices=(0, 2)),
|
||||
backbone_kwargs={"out_indices": (0, 1)},
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_backbone_in_new_model(self):
|
||||
|
||||
Reference in New Issue
Block a user