Add kwargs for timm.create_model in TimmWrapper (#38860)
* Add init kwargs for timm wrapper * model_init_kwargs -> model_args * add save-load test * fixup
This commit is contained in:
committed by
GitHub
parent
ff95974bc6
commit
9120567b02
@@ -237,6 +237,24 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
self.assertEqual(config.id2label, restored_config.id2label)
|
||||
self.assertEqual(config.label2id, restored_config.label2id)
|
||||
|
||||
def test_model_init_args(self):
|
||||
# test init from config
|
||||
config = TimmWrapperConfig.from_pretrained(
|
||||
"timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
|
||||
model_args={"depth": 3},
|
||||
)
|
||||
model = TimmWrapperModel(config)
|
||||
self.assertEqual(len(model.timm_model.blocks), 3)
|
||||
|
||||
cls_model = TimmWrapperForImageClassification(config)
|
||||
self.assertEqual(len(cls_model.timm_model.blocks), 3)
|
||||
|
||||
# test save load
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
restored_model = TimmWrapperModel.from_pretrained(tmpdirname)
|
||||
self.assertEqual(len(restored_model.timm_model.blocks), 3)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
Reference in New Issue
Block a user