Add TimmBackbone model (#22619)

* Add test_backbone for convnext

* Add TimmBackbone model

* Add check for backbone type

* Tidying up - config checks

* Update convnextv2

* Tidy up

* Fix indices & clearer comment

* Exceptions for config checks

* Correclty update config for tests

* Safer imports

* Safer safer imports

* Fix where decorators go

* Update import logic and backbone tests

* More import fixes

* Fixup

* Only import all_models if torch available

* Fix kwarg updates in from_pretrained & main rebase

* Tidy up

* Add tests for AutoBackbone

* Tidy up

* Fix import error

* Fix up

* Install nattan in doc_test_job

* Revert back to setting self._out_xxx directly

* Bug fix - out_indices mapping from out_features

* Fix tests

* Dont accept output_loading_info for Timm models

* Set out_xxx and don't remap

* Use smaller checkpoint for test

* Don't remap timm indices - check out_indices based on stage names

* Skip test as it's n/a

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Cleaner imports / spelling is hard

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
amyeroberts
2023-06-06 17:11:30 +01:00
committed by GitHub
parent b8935980a2
commit a717e0318c
29 changed files with 753 additions and 88 deletions

View File

@@ -17,6 +17,7 @@ import copy
import inspect
from transformers.testing_utils import require_torch, torch_device
from transformers.utils.backbone_utils import BackboneType
@require_torch
@@ -104,6 +105,8 @@ class BackboneTesterMixin:
self.assertEqual(len(result.feature_maps), len(config.out_features))
self.assertEqual(len(model.channels), len(config.out_features))
self.assertEqual(len(result.feature_maps), len(config.out_indices))
self.assertEqual(len(model.channels), len(config.out_indices))
# Check output of last stage is taken if out_features=None, out_indices=None
modified_config = copy.deepcopy(config)
@@ -140,6 +143,7 @@ class BackboneTesterMixin:
for backbone_class in self.all_model_classes:
backbone = backbone_class(config)
self.assertTrue(hasattr(backbone, "backbone_type"))
self.assertTrue(hasattr(backbone, "stage_names"))
self.assertTrue(hasattr(backbone, "num_features"))
self.assertTrue(hasattr(backbone, "out_indices"))
@@ -147,6 +151,7 @@ class BackboneTesterMixin:
self.assertTrue(hasattr(backbone, "out_feature_channels"))
self.assertTrue(hasattr(backbone, "channels"))
self.assertIsInstance(backbone.backbone_type, BackboneType)
# Verify num_features has been initialized in the backbone init
self.assertIsNotNone(backbone.num_features)
self.assertTrue(len(backbone.channels) == len(backbone.out_indices))