Add kwargs for timm.create_model in TimmWrapper (#38860)

* Add init kwargs for timm wrapper

* model_init_kwargs -> model_args

* add save-load test

* fixup
This commit is contained in:
Pavel Iakubovskii
2025-06-20 13:00:09 +01:00
committed by GitHub
parent ff95974bc6
commit 9120567b02
3 changed files with 36 additions and 4 deletions

View File

@@ -237,6 +237,24 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
self.assertEqual(config.id2label, restored_config.id2label)
self.assertEqual(config.label2id, restored_config.label2id)
def test_model_init_args(self):
# test init from config
config = TimmWrapperConfig.from_pretrained(
"timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
model_args={"depth": 3},
)
model = TimmWrapperModel(config)
self.assertEqual(len(model.timm_model.blocks), 3)
cls_model = TimmWrapperForImageClassification(config)
self.assertEqual(len(cls_model.timm_model.blocks), 3)
# test save load
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
restored_model = TimmWrapperModel.from_pretrained(tmpdirname)
self.assertEqual(len(restored_model.timm_model.blocks), 3)
# We will verify our results on an image of cute cats
def prepare_img():