From 59e5b3f01b7773439671c3a827348ba87dc8b92a Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 8 Jan 2025 14:09:46 +0000 Subject: [PATCH] Timm wrapper label names (#35553) * Add timm wrapper label names mapping * Add index to classification pipeline * Revert adding index for pipelines * Add custom model check for loading timm labels * Add tests for labels * [run-slow] timm_wrapper * Add note regarding label2id mapping --- .../configuration_timm_wrapper.py | 34 ++++++++++++- .../test_modeling_timm_wrapper.py | 48 +++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 691a2b2b76..9562e28a8b 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -18,7 +18,11 @@ from typing import Any, Dict from ...configuration_utils import PretrainedConfig -from ...utils import logging +from ...utils import is_timm_available, logging + + +if is_timm_available(): + from timm.data import ImageNetInfo, infer_imagenet_subset logger = logging.get_logger(__name__) @@ -33,6 +37,9 @@ class TimmWrapperConfig(PretrainedConfig): Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default + imagenet models is set to `None` due to duplicates in the label descriptions. + Args: initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
@@ -60,10 +67,30 @@ class TimmWrapperConfig(PretrainedConfig): @classmethod def from_dict(cls, config_dict: Dict[str, Any], **kwargs): + label_names = config_dict.get("label_names", None) + is_custom_model = "num_labels" in kwargs or "id2label" in kwargs + + # if no labels added to config, use imagenet labeller in timm + if label_names is None and not is_custom_model: + imagenet_subset = infer_imagenet_subset(config_dict) + if imagenet_subset: + dataset_info = ImageNetInfo(imagenet_subset) + synsets = dataset_info.label_names() + label_descriptions = dataset_info.label_descriptions(as_dict=True) + label_names = [label_descriptions[synset] for synset in synsets] + + if label_names is not None and not is_custom_model: + kwargs["id2label"] = dict(enumerate(label_names)) + + # if all label names are unique, create label2id mapping as well + if len(set(label_names)) == len(label_names): + kwargs["label2id"] = {name: i for i, name in enumerate(label_names)} + else: + kwargs["label2id"] = None + # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict. 
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config # and to avoid duplicate attributes - num_labels_in_kwargs = kwargs.pop("num_labels", None) num_labels_in_dict = config_dict.pop("num_classes", None) @@ -80,6 +107,9 @@ class TimmWrapperConfig(PretrainedConfig): def to_dict(self) -> Dict[str, Any]: output = super().to_dict() output["num_classes"] = self.num_labels + output["label_names"] = list(self.id2label.values()) + output.pop("id2label", None) + output.pop("label2id", None) return output diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py index 6f63c0aa14..cf35d90518 100644 --- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -200,6 +200,54 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC output = model(**inputs_dict, do_pooling=True) self.assertIsNotNone(output.pooler_output) + def test_timm_config_labels(self): + # test timm config with no labels + checkpoint = "timm/resnet18.a1_in1k" + config = TimmWrapperConfig.from_pretrained(checkpoint) + self.assertIsNone(config.label2id) + self.assertIsInstance(config.id2label, dict) + self.assertEqual(len(config.id2label), 1000) + self.assertEqual(config.id2label[1], "goldfish, Carassius auratus") + + # test timm config with labels in config + checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21" + config = TimmWrapperConfig.from_pretrained(checkpoint) + + self.assertIsInstance(config.id2label, dict) + self.assertEqual(len(config.id2label), 10000) + self.assertEqual(config.id2label[1], "Sabella spallanzanii") + + self.assertIsInstance(config.label2id, dict) + self.assertEqual(len(config.label2id), 10000) + self.assertEqual(config.label2id["Sabella spallanzanii"], 1) + + # test custom labels are provided + checkpoint = "timm/resnet18.a1_in1k" + config = 
TimmWrapperConfig.from_pretrained(checkpoint, num_labels=2) + self.assertEqual(config.num_labels, 2) + self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"}) + self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1}) + + # test with provided id2label and label2id + checkpoint = "timm/resnet18.a1_in1k" + config = TimmWrapperConfig.from_pretrained( + checkpoint, num_labels=2, id2label={0: "LABEL_0", 1: "LABEL_1"}, label2id={"LABEL_0": 0, "LABEL_1": 1} + ) + self.assertEqual(config.num_labels, 2) + self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"}) + self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1}) + + # test save load + checkpoint = "timm/resnet18.a1_in1k" + config = TimmWrapperConfig.from_pretrained(checkpoint) + with tempfile.TemporaryDirectory() as tmpdirname: + config.save_pretrained(tmpdirname) + restored_config = TimmWrapperConfig.from_pretrained(tmpdirname) + + self.assertEqual(config.num_labels, restored_config.num_labels) + self.assertEqual(config.id2label, restored_config.id2label) + self.assertEqual(config.label2id, restored_config.label2id) + # We will verify our results on an image of cute cats def prepare_img():