Add TimmBackbone model (#22619)
* Add test_backbone for convnext * Add TimmBackbone model * Add check for backbone type * Tidying up - config checks * Update convnextv2 * Tidy up * Fix indices & clearer comment * Exceptions for config checks * Correclty update config for tests * Safer imports * Safer safer imports * Fix where decorators go * Update import logic and backbone tests * More import fixes * Fixup * Only import all_models if torch available * Fix kwarg updates in from_pretrained & main rebase * Tidy up * Add tests for AutoBackbone * Tidy up * Fix import error * Fix up * Install nattan in doc_test_job * Revert back to setting self._out_xxx directly * Bug fix - out_indices mapping from out_features * Fix tests * Dont accept output_loading_info for Timm models * Set out_xxx and don't remap * Use smaller checkpoint for test * Don't remap timm indices - check out_indices based on stage names * Skip test as it's n/a * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Cleaner imports / spelling is hard --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -45,6 +45,7 @@ if is_torch_available():
|
||||
from test_module.custom_modeling import CustomModel
|
||||
|
||||
from transformers import (
|
||||
AutoBackbone,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
@@ -66,11 +67,13 @@ if is_torch_available():
|
||||
FunnelModel,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
ResNetBackbone,
|
||||
RobertaForMaskedLM,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
TapasConfig,
|
||||
TapasForQuestionAnswering,
|
||||
TimmBackbone,
|
||||
)
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
@@ -224,6 +227,42 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForTokenClassification)
|
||||
|
||||
@slow
|
||||
def test_auto_backbone_timm_model_from_pretrained(self):
|
||||
# Configs can't be loaded for timm models
|
||||
model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# We can't pass output_loading_info=True as we're loading from timm
|
||||
AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, TimmBackbone)
|
||||
|
||||
# Check kwargs are correctly passed to the backbone
|
||||
model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
|
||||
self.assertEqual(model.out_indices, (-1, -2))
|
||||
|
||||
# Check out_features cannot be passed to Timm backbones
|
||||
with self.assertRaises(ValueError):
|
||||
_ = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])
|
||||
|
||||
@slow
|
||||
def test_auto_backbone_from_pretrained(self):
|
||||
model = AutoBackbone.from_pretrained("microsoft/resnet-18")
|
||||
model, loading_info = AutoBackbone.from_pretrained("microsoft/resnet-18", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, ResNetBackbone)
|
||||
|
||||
# Check kwargs are correctly passed to the backbone
|
||||
model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2])
|
||||
self.assertEqual(model.out_indices, [-1, -2])
|
||||
self.assertEqual(model.out_features, ["stage4", "stage3"])
|
||||
|
||||
model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"])
|
||||
self.assertEqual(model.out_indices, [2, 4])
|
||||
self.assertEqual(model.out_features, ["stage2", "stage4"])
|
||||
|
||||
def test_from_pretrained_identifier(self):
|
||||
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
|
||||
Reference in New Issue
Block a user