Timm wrapper label names (#35553)
* Add timm wrapper label names mapping * Add index to classification pipeline * Revert adding index for pipelines * Add custom model check for loading timm labels * Add tests for labels * [run-slow] timm_wrapper * Add note regarding label2id mapping
This commit is contained in:
committed by
GitHub
parent
f1639ea51d
commit
59e5b3f01b
@@ -18,7 +18,11 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils import is_timm_available, logging
|
||||
|
||||
|
||||
if is_timm_available():
|
||||
from timm.data import ImageNetInfo, infer_imagenet_subset
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -33,6 +37,9 @@ class TimmWrapperConfig(PretrainedConfig):
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default
|
||||
imagenet models is set to `None` due to occlusions in the label descriptions.
|
||||
|
||||
Args:
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
@@ -60,10 +67,30 @@ class TimmWrapperConfig(PretrainedConfig):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
||||
label_names = config_dict.get("label_names", None)
|
||||
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
|
||||
|
||||
# if no labels added to config, use imagenet labeller in timm
|
||||
if label_names is None and not is_custom_model:
|
||||
imagenet_subset = infer_imagenet_subset(config_dict)
|
||||
if imagenet_subset:
|
||||
dataset_info = ImageNetInfo(imagenet_subset)
|
||||
synsets = dataset_info.label_names()
|
||||
label_descriptions = dataset_info.label_descriptions(as_dict=True)
|
||||
label_names = [label_descriptions[synset] for synset in synsets]
|
||||
|
||||
if label_names is not None and not is_custom_model:
|
||||
kwargs["id2label"] = dict(enumerate(label_names))
|
||||
|
||||
# if all label names are unique, create label2id mapping as well
|
||||
if len(set(label_names)) == len(label_names):
|
||||
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
|
||||
else:
|
||||
kwargs["label2id"] = None
|
||||
|
||||
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
|
||||
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
|
||||
# and to avoid duplicate attributes
|
||||
|
||||
num_labels_in_kwargs = kwargs.pop("num_labels", None)
|
||||
num_labels_in_dict = config_dict.pop("num_classes", None)
|
||||
|
||||
@@ -80,6 +107,9 @@ class TimmWrapperConfig(PretrainedConfig):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
output = super().to_dict()
|
||||
output["num_classes"] = self.num_labels
|
||||
output["label_names"] = list(self.id2label.values())
|
||||
output.pop("id2label", None)
|
||||
output.pop("label2id", None)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
@@ -200,6 +200,54 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
output = model(**inputs_dict, do_pooling=True)
|
||||
self.assertIsNotNone(output.pooler_output)
|
||||
|
||||
def test_timm_config_labels(self):
|
||||
# test timm config with no labels
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||
self.assertIsNone(config.label2id)
|
||||
self.assertIsInstance(config.id2label, dict)
|
||||
self.assertEqual(len(config.id2label), 1000)
|
||||
self.assertEqual(config.id2label[1], "goldfish, Carassius auratus")
|
||||
|
||||
# test timm config with labels in config
|
||||
checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||
|
||||
self.assertIsInstance(config.id2label, dict)
|
||||
self.assertEqual(len(config.id2label), 10000)
|
||||
self.assertEqual(config.id2label[1], "Sabella spallanzanii")
|
||||
|
||||
self.assertIsInstance(config.label2id, dict)
|
||||
self.assertEqual(len(config.label2id), 10000)
|
||||
self.assertEqual(config.label2id["Sabella spallanzanii"], 1)
|
||||
|
||||
# test custom labels are provided
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint, num_labels=2)
|
||||
self.assertEqual(config.num_labels, 2)
|
||||
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
|
||||
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
|
||||
|
||||
# test with provided id2label and label2id
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(
|
||||
checkpoint, num_labels=2, id2label={0: "LABEL_0", 1: "LABEL_1"}, label2id={"LABEL_0": 0, "LABEL_1": 1}
|
||||
)
|
||||
self.assertEqual(config.num_labels, 2)
|
||||
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
|
||||
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
|
||||
|
||||
# test save load
|
||||
checkpoint = "timm/resnet18.a1_in1k"
|
||||
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
config.save_pretrained(tmpdirname)
|
||||
restored_config = TimmWrapperConfig.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertEqual(config.num_labels, restored_config.num_labels)
|
||||
self.assertEqual(config.id2label, restored_config.id2label)
|
||||
self.assertEqual(config.label2id, restored_config.label2id)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
Reference in New Issue
Block a user