Timm wrapper label names (#35553)

* Add timm wrapper label names mapping

* Add index to classification pipeline

* Revert adding index for pipelines

* Add custom model check for loading timm labels

* Add tests for labels

* [run-slow] timm_wrapper

* Add note regarding label2id mapping
This commit is contained in:
Pavel Iakubovskii
2025-01-08 14:09:46 +00:00
committed by GitHub
parent f1639ea51d
commit 59e5b3f01b
2 changed files with 80 additions and 2 deletions

View File

@@ -18,7 +18,11 @@
from typing import Any, Dict from typing import Any, Dict
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import is_timm_available, logging
if is_timm_available():
from timm.data import ImageNetInfo, infer_imagenet_subset
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -33,6 +37,9 @@ class TimmWrapperConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
The config loads ImageNet label descriptions and stores them in the `id2label` attribute. For default ImageNet
models the `label2id` attribute is set to `None`, because some label descriptions are duplicated and therefore
cannot be mapped back to a unique id.
Args: Args:
initializer_range (`float`, *optional*, defaults to 0.02): initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@@ -60,10 +67,30 @@ class TimmWrapperConfig(PretrainedConfig):
@classmethod @classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs): def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
label_names = config_dict.get("label_names", None)
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
# if no labels added to config, use imagenet labeller in timm
if label_names is None and not is_custom_model:
imagenet_subset = infer_imagenet_subset(config_dict)
if imagenet_subset:
dataset_info = ImageNetInfo(imagenet_subset)
synsets = dataset_info.label_names()
label_descriptions = dataset_info.label_descriptions(as_dict=True)
label_names = [label_descriptions[synset] for synset in synsets]
if label_names is not None and not is_custom_model:
kwargs["id2label"] = dict(enumerate(label_names))
# if all label names are unique, create label2id mapping as well
if len(set(label_names)) == len(label_names):
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
else:
kwargs["label2id"] = None
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
# and to avoid duplicate attributes # and to avoid duplicate attributes
num_labels_in_kwargs = kwargs.pop("num_labels", None) num_labels_in_kwargs = kwargs.pop("num_labels", None)
num_labels_in_dict = config_dict.pop("num_classes", None) num_labels_in_dict = config_dict.pop("num_classes", None)
@@ -80,6 +107,9 @@ class TimmWrapperConfig(PretrainedConfig):
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize the config to a dictionary.

    On top of the parent serialization, `num_labels` is written under the `num_classes` key and the labels are
    written as a flat `label_names` list; the `id2label` and `label2id` entries are removed from the output.

    Returns:
        `Dict[str, Any]`: The serialized config.
    """
    output = super().to_dict()
    output["num_classes"] = self.num_labels
    # `from_dict` rebuilds `id2label` with `dict(enumerate(label_names))`, so the position of a name in the list
    # becomes its class id. Order the names by id instead of relying on dict insertion order, otherwise an
    # `id2label` that was not inserted in id order would be silently re-mapped after a save/load round trip.
    ordered_labels = sorted(self.id2label.items(), key=lambda item: int(item[0]))
    output["label_names"] = [name for _, name in ordered_labels]
    output.pop("id2label", None)
    output.pop("label2id", None)
    return output

View File

@@ -200,6 +200,54 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
output = model(**inputs_dict, do_pooling=True) output = model(**inputs_dict, do_pooling=True)
self.assertIsNotNone(output.pooler_output) self.assertIsNotNone(output.pooler_output)
def test_timm_config_labels(self):
    """Check the `id2label` / `label2id` values of `TimmWrapperConfig` for several loading scenarios."""
    imagenet_checkpoint = "timm/resnet18.a1_in1k"
    inat_checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
    two_class_id2label = {0: "LABEL_0", 1: "LABEL_1"}
    two_class_label2id = {"LABEL_0": 0, "LABEL_1": 1}

    # timm config with no labels
    imagenet_config = TimmWrapperConfig.from_pretrained(imagenet_checkpoint)
    self.assertIsNone(imagenet_config.label2id)
    self.assertIsInstance(imagenet_config.id2label, dict)
    self.assertEqual(len(imagenet_config.id2label), 1000)
    self.assertEqual(imagenet_config.id2label[1], "goldfish, Carassius auratus")

    # timm config with labels in config
    inat_config = TimmWrapperConfig.from_pretrained(inat_checkpoint)
    self.assertIsInstance(inat_config.id2label, dict)
    self.assertEqual(len(inat_config.id2label), 10000)
    self.assertEqual(inat_config.id2label[1], "Sabella spallanzanii")
    self.assertIsInstance(inat_config.label2id, dict)
    self.assertEqual(len(inat_config.label2id), 10000)
    self.assertEqual(inat_config.label2id["Sabella spallanzanii"], 1)

    # custom number of labels is provided
    num_labels_config = TimmWrapperConfig.from_pretrained(imagenet_checkpoint, num_labels=2)
    self.assertEqual(num_labels_config.num_labels, 2)
    self.assertEqual(num_labels_config.id2label, two_class_id2label)
    self.assertEqual(num_labels_config.label2id, two_class_label2id)

    # explicit id2label and label2id are provided (copies are passed so the expected dicts stay independent)
    explicit_config = TimmWrapperConfig.from_pretrained(
        imagenet_checkpoint,
        num_labels=2,
        id2label=dict(two_class_id2label),
        label2id=dict(two_class_label2id),
    )
    self.assertEqual(explicit_config.num_labels, 2)
    self.assertEqual(explicit_config.id2label, two_class_id2label)
    self.assertEqual(explicit_config.label2id, two_class_label2id)

    # save / load round trip
    original_config = TimmWrapperConfig.from_pretrained(imagenet_checkpoint)
    with tempfile.TemporaryDirectory() as save_dir:
        original_config.save_pretrained(save_dir)
        reloaded_config = TimmWrapperConfig.from_pretrained(save_dir)

    self.assertEqual(original_config.num_labels, reloaded_config.num_labels)
    self.assertEqual(original_config.id2label, reloaded_config.id2label)
    self.assertEqual(original_config.label2id, reloaded_config.label2id)
# We will verify our results on an image of cute cats # We will verify our results on an image of cute cats
def prepare_img(): def prepare_img():