Add TimmBackbone model (#22619)
* Add test_backbone for convnext * Add TimmBackbone model * Add check for backbone type * Tidying up - config checks * Update convnextv2 * Tidy up * Fix indices & clearer comment * Exceptions for config checks * Correclty update config for tests * Safer imports * Safer safer imports * Fix where decorators go * Update import logic and backbone tests * More import fixes * Fixup * Only import all_models if torch available * Fix kwarg updates in from_pretrained & main rebase * Tidy up * Add tests for AutoBackbone * Tidy up * Fix import error * Fix up * Install nattan in doc_test_job * Revert back to setting self._out_xxx directly * Bug fix - out_indices mapping from out_features * Fix tests * Dont accept output_loading_info for Timm models * Set out_xxx and don't remap * Use smaller checkpoint for test * Don't remap timm indices - check out_indices based on stage names * Skip test as it's n/a * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Cleaner imports / spelling is hard --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import copy
|
||||
import inspect
|
||||
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.utils.backbone_utils import BackboneType
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -104,6 +105,8 @@ class BackboneTesterMixin:
|
||||
|
||||
self.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.assertEqual(len(model.channels), len(config.out_features))
|
||||
self.assertEqual(len(result.feature_maps), len(config.out_indices))
|
||||
self.assertEqual(len(model.channels), len(config.out_indices))
|
||||
|
||||
# Check output of last stage is taken if out_features=None, out_indices=None
|
||||
modified_config = copy.deepcopy(config)
|
||||
@@ -140,6 +143,7 @@ class BackboneTesterMixin:
|
||||
for backbone_class in self.all_model_classes:
|
||||
backbone = backbone_class(config)
|
||||
|
||||
self.assertTrue(hasattr(backbone, "backbone_type"))
|
||||
self.assertTrue(hasattr(backbone, "stage_names"))
|
||||
self.assertTrue(hasattr(backbone, "num_features"))
|
||||
self.assertTrue(hasattr(backbone, "out_indices"))
|
||||
@@ -147,6 +151,7 @@ class BackboneTesterMixin:
|
||||
self.assertTrue(hasattr(backbone, "out_feature_channels"))
|
||||
self.assertTrue(hasattr(backbone, "channels"))
|
||||
|
||||
self.assertIsInstance(backbone.backbone_type, BackboneType)
|
||||
# Verify num_features has been initialized in the backbone init
|
||||
self.assertIsNotNone(backbone.num_features)
|
||||
self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
|
||||
|
||||
Reference in New Issue
Block a user