Timm wrapper label names (#35553)
* Add timm wrapper label names mapping * Add index to classification pipeline * Revert adding index for pipelines * Add custom model check for loading timm labels * Add tests for labels * [run-slow] timm_wrapper * Add note regarding label2id mapping
This commit is contained in:
committed by
GitHub
parent
f1639ea51d
commit
59e5b3f01b
@@ -200,6 +200,54 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
output = model(**inputs_dict, do_pooling=True)
|
||||
self.assertIsNotNone(output.pooler_output)
|
||||
|
||||
def test_timm_config_labels(self):
|
||||
# test timm config with no labels
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||
self.assertIsNone(config.label2id)
|
||||
self.assertIsInstance(config.id2label, dict)
|
||||
self.assertEqual(len(config.id2label), 1000)
|
||||
self.assertEqual(config.id2label[1], "goldfish, Carassius auratus")
|
||||
|
||||
# test timm config with labels in config
|
||||
checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||
|
||||
self.assertIsInstance(config.id2label, dict)
|
||||
self.assertEqual(len(config.id2label), 10000)
|
||||
self.assertEqual(config.id2label[1], "Sabella spallanzanii")
|
||||
|
||||
self.assertIsInstance(config.label2id, dict)
|
||||
self.assertEqual(len(config.label2id), 10000)
|
||||
self.assertEqual(config.label2id["Sabella spallanzanii"], 1)
|
||||
|
||||
# test custom labels are provided
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint, num_labels=2)
|
||||
self.assertEqual(config.num_labels, 2)
|
||||
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
|
||||
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
|
||||
|
||||
# test with provided id2label and label2id
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(
|
||||
checkpoint, num_labels=2, id2label={0: "LABEL_0", 1: "LABEL_1"}, label2id={"LABEL_0": 0, "LABEL_1": 1}
|
||||
)
|
||||
self.assertEqual(config.num_labels, 2)
|
||||
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
|
||||
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
|
||||
|
||||
# test save load
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
config.save_pretrained(tmpdirname)
|
||||
restored_config = TimmWrapperConfig.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertEqual(config.num_labels, restored_config.num_labels)
|
||||
self.assertEqual(config.id2label, restored_config.id2label)
|
||||
self.assertEqual(config.label2id, restored_config.label2id)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
Reference in New Issue
Block a user