Timm wrapper label names (#35553)
* Add timm wrapper label names mapping * Add index to classification pipeline * Revert adding index for pipelines * Add custom model check for loading timm labels * Add tests for labels * [run-slow] timm_wrapper * Add note regarding label2id mapping
This commit is contained in:
committed by
GitHub
parent
f1639ea51d
commit
59e5b3f01b
@@ -18,7 +18,11 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import is_timm_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_timm_available():
|
||||||
|
from timm.data import ImageNetInfo, infer_imagenet_subset
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -33,6 +37,9 @@ class TimmWrapperConfig(PretrainedConfig):
|
|||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default
|
||||||
|
imagenet models is set to `None` due to occlusions in the label descriptions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
@@ -60,10 +67,30 @@ class TimmWrapperConfig(PretrainedConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
||||||
|
label_names = config_dict.get("label_names", None)
|
||||||
|
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
|
||||||
|
|
||||||
|
# if no labels added to config, use imagenet labeller in timm
|
||||||
|
if label_names is None and not is_custom_model:
|
||||||
|
imagenet_subset = infer_imagenet_subset(config_dict)
|
||||||
|
if imagenet_subset:
|
||||||
|
dataset_info = ImageNetInfo(imagenet_subset)
|
||||||
|
synsets = dataset_info.label_names()
|
||||||
|
label_descriptions = dataset_info.label_descriptions(as_dict=True)
|
||||||
|
label_names = [label_descriptions[synset] for synset in synsets]
|
||||||
|
|
||||||
|
if label_names is not None and not is_custom_model:
|
||||||
|
kwargs["id2label"] = dict(enumerate(label_names))
|
||||||
|
|
||||||
|
# if all label names are unique, create label2id mapping as well
|
||||||
|
if len(set(label_names)) == len(label_names):
|
||||||
|
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
|
||||||
|
else:
|
||||||
|
kwargs["label2id"] = None
|
||||||
|
|
||||||
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
|
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
|
||||||
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
|
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
|
||||||
# and to avoid duplicate attributes
|
# and to avoid duplicate attributes
|
||||||
|
|
||||||
num_labels_in_kwargs = kwargs.pop("num_labels", None)
|
num_labels_in_kwargs = kwargs.pop("num_labels", None)
|
||||||
num_labels_in_dict = config_dict.pop("num_classes", None)
|
num_labels_in_dict = config_dict.pop("num_classes", None)
|
||||||
|
|
||||||
@@ -80,6 +107,9 @@ class TimmWrapperConfig(PretrainedConfig):
|
|||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
output = super().to_dict()
|
output = super().to_dict()
|
||||||
output["num_classes"] = self.num_labels
|
output["num_classes"] = self.num_labels
|
||||||
|
output["label_names"] = list(self.id2label.values())
|
||||||
|
output.pop("id2label", None)
|
||||||
|
output.pop("label2id", None)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -200,6 +200,54 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
output = model(**inputs_dict, do_pooling=True)
|
output = model(**inputs_dict, do_pooling=True)
|
||||||
self.assertIsNotNone(output.pooler_output)
|
self.assertIsNotNone(output.pooler_output)
|
||||||
|
|
||||||
|
def test_timm_config_labels(self):
|
||||||
|
# test timm config with no labels
|
||||||
|
checkpoint = "timm/resnet18.a1_in1k"
|
||||||
|
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||||
|
self.assertIsNone(config.label2id)
|
||||||
|
self.assertIsInstance(config.id2label, dict)
|
||||||
|
self.assertEqual(len(config.id2label), 1000)
|
||||||
|
self.assertEqual(config.id2label[1], "goldfish, Carassius auratus")
|
||||||
|
|
||||||
|
# test timm config with labels in config
|
||||||
|
checkpoint = "timm/eva02_large_patch14_clip_336.merged2b_ft_inat21"
|
||||||
|
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||||
|
|
||||||
|
self.assertIsInstance(config.id2label, dict)
|
||||||
|
self.assertEqual(len(config.id2label), 10000)
|
||||||
|
self.assertEqual(config.id2label[1], "Sabella spallanzanii")
|
||||||
|
|
||||||
|
self.assertIsInstance(config.label2id, dict)
|
||||||
|
self.assertEqual(len(config.label2id), 10000)
|
||||||
|
self.assertEqual(config.label2id["Sabella spallanzanii"], 1)
|
||||||
|
|
||||||
|
# test custom labels are provided
|
||||||
|
checkpoint = "timm/resnet18.a1_in1k"
|
||||||
|
config = TimmWrapperConfig.from_pretrained(checkpoint, num_labels=2)
|
||||||
|
self.assertEqual(config.num_labels, 2)
|
||||||
|
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
|
||||||
|
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
|
||||||
|
|
||||||
|
# test with provided id2label and label2id
|
||||||
|
checkpoint = "timm/resnet18.a1_in1k"
|
||||||
|
config = TimmWrapperConfig.from_pretrained(
|
||||||
|
checkpoint, num_labels=2, id2label={0: "LABEL_0", 1: "LABEL_1"}, label2id={"LABEL_0": 0, "LABEL_1": 1}
|
||||||
|
)
|
||||||
|
self.assertEqual(config.num_labels, 2)
|
||||||
|
self.assertEqual(config.id2label, {0: "LABEL_0", 1: "LABEL_1"})
|
||||||
|
self.assertEqual(config.label2id, {"LABEL_0": 0, "LABEL_1": 1})
|
||||||
|
|
||||||
|
# test save load
|
||||||
|
checkpoint = "timm/resnet18.a1_in1k"
|
||||||
|
config = TimmWrapperConfig.from_pretrained(checkpoint)
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
config.save_pretrained(tmpdirname)
|
||||||
|
restored_config = TimmWrapperConfig.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
self.assertEqual(config.num_labels, restored_config.num_labels)
|
||||||
|
self.assertEqual(config.id2label, restored_config.id2label)
|
||||||
|
self.assertEqual(config.label2id, restored_config.label2id)
|
||||||
|
|
||||||
|
|
||||||
# We will verify our results on an image of cute cats
|
# We will verify our results on an image of cute cats
|
||||||
def prepare_img():
|
def prepare_img():
|
||||||
|
|||||||
Reference in New Issue
Block a user