Timm wrapper label names (#35553)

* Add timm wrapper label names mapping

* Add index to classification pipeline

* Revert adding index for pipelines

* Add custom model check for loading timm labels

* Add tests for labels

* [run-slow] timm_wrapper

* Add note regarding label2id mapping
This commit is contained in:
Pavel Iakubovskii
2025-01-08 14:09:46 +00:00
committed by GitHub
parent f1639ea51d
commit 59e5b3f01b
2 changed files with 80 additions and 2 deletions

View File

@@ -200,6 +200,54 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
output = model(**inputs_dict, do_pooling=True)
self.assertIsNotNone(output.pooler_output)
def test_timm_config_labels(self):
# test timm config with no labels
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(checkpoint)
self.assertIsNone(config.label2id)
self.assertIsInstance(config.id2label, dict)
self.assertEqual(len(config.id2label), 1000)
self.assertEqual(config.id2label[1], "goldfish, Carassius auratus")
# test timm config with labels in config
checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
config = TimmWrapperConfig.from_pretrained(checkpoint)
self.assertIsInstance(config.id2label, dict)
self.assertEqual(len(config.id2label), 10000)
self.assertEqual(config.id2label[1], "Sabella spallanzanii")
self.assertIsInstance(config.label2id, dict)
self.assertEqual(len(config.label2id), 10000)
self.assertEqual(config.label2id["Sabella spallanzanii"], 1)
# test custom labels are provided
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(checkpoint, num_labels=2)
self.assertEqual(config.num_labels, 2)
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
# test with provided id2label and label2id
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(
checkpoint, num_labels=2, id2label={0: "LABEL_0", 1: "LABEL_1"}, label2id={"LABEL_0": 0, "LABEL_1": 1}
)
self.assertEqual(config.num_labels, 2)
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
# test save load
checkpoint = "timm/resnet18.a1_in1k"
config = TimmWrapperConfig.from_pretrained(checkpoint)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_pretrained(tmpdirname)
restored_config = TimmWrapperConfig.from_pretrained(tmpdirname)
self.assertEqual(config.num_labels, restored_config.num_labels)
self.assertEqual(config.id2label, restored_config.id2label)
self.assertEqual(config.label2id, restored_config.label2id)
# We will verify our results on an image of cute cats
def prepare_img():