Add support for fine-tuning CLIP-like models using contrastive-image-text example (#29070)
* Add support for SigLIP and Chinese-CLIP model training with the contrastive-image-text example
* Codebase fixups
This commit is contained in:
committed by
GitHub
parent
0996a10077
commit
ee3af60be0
@@ -54,6 +54,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("camembert", "CamembertConfig"),
|
("camembert", "CamembertConfig"),
|
||||||
("canine", "CanineConfig"),
|
("canine", "CanineConfig"),
|
||||||
("chinese_clip", "ChineseCLIPConfig"),
|
("chinese_clip", "ChineseCLIPConfig"),
|
||||||
|
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
|
||||||
("clap", "ClapConfig"),
|
("clap", "ClapConfig"),
|
||||||
("clip", "CLIPConfig"),
|
("clip", "CLIPConfig"),
|
||||||
("clip_vision_model", "CLIPVisionConfig"),
|
("clip_vision_model", "CLIPVisionConfig"),
|
||||||
@@ -512,6 +513,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("camembert", "CamemBERT"),
|
("camembert", "CamemBERT"),
|
||||||
("canine", "CANINE"),
|
("canine", "CANINE"),
|
||||||
("chinese_clip", "Chinese-CLIP"),
|
("chinese_clip", "Chinese-CLIP"),
|
||||||
|
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
|
||||||
("clap", "CLAP"),
|
("clap", "CLAP"),
|
||||||
("clip", "CLIP"),
|
("clip", "CLIP"),
|
||||||
("clip_vision_model", "CLIPVisionModel"),
|
("clip_vision_model", "CLIPVisionModel"),
|
||||||
@@ -773,6 +775,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
|||||||
("xclip", "x_clip"),
|
("xclip", "x_clip"),
|
||||||
("clip_vision_model", "clip"),
|
("clip_vision_model", "clip"),
|
||||||
("siglip_vision_model", "siglip"),
|
("siglip_vision_model", "siglip"),
|
||||||
|
("chinese_clip_vision_model", "chinese_clip"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("camembert", "CamembertModel"),
|
("camembert", "CamembertModel"),
|
||||||
("canine", "CanineModel"),
|
("canine", "CanineModel"),
|
||||||
("chinese_clip", "ChineseCLIPModel"),
|
("chinese_clip", "ChineseCLIPModel"),
|
||||||
|
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
|
||||||
("clap", "ClapModel"),
|
("clap", "ClapModel"),
|
||||||
("clip", "CLIPModel"),
|
("clip", "CLIPModel"),
|
||||||
("clip_vision_model", "CLIPVisionModel"),
|
("clip_vision_model", "CLIPVisionModel"),
|
||||||
|
|||||||
@@ -171,8 +171,7 @@ class ChineseCLIPVisionConfig(PretrainedConfig):
|
|||||||
This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
|
This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
|
||||||
ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
|
ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
|
configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
|
||||||
[OFA-Sys/chinese-clip-vit-base-patch16](https:
|
[OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
|
||||||
//huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
|
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|||||||
@@ -18,11 +18,19 @@
|
|||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
from ..chinese_clip.configuration_chinese_clip import ChineseCLIPVisionConfig
|
||||||
from ..clip.configuration_clip import CLIPVisionConfig
|
from ..clip.configuration_clip import CLIPVisionConfig
|
||||||
|
from ..siglip.configuration_siglip import SiglipVisionConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
VISION_MODEL_CONFIGS = {
|
||||||
|
"clip_vision_model": CLIPVisionConfig,
|
||||||
|
"chinese_clip_vision_model": ChineseCLIPVisionConfig,
|
||||||
|
"siglip_vision_model": SiglipVisionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class VisionTextDualEncoderConfig(PretrainedConfig):
|
class VisionTextDualEncoderConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
@@ -85,12 +93,13 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
|
|||||||
vision_model_type = vision_config.pop("model_type")
|
vision_model_type = vision_config.pop("model_type")
|
||||||
text_model_type = text_config.pop("model_type")
|
text_model_type = text_config.pop("model_type")
|
||||||
|
|
||||||
if vision_model_type == "clip":
|
vision_config_class = VISION_MODEL_CONFIGS.get(vision_model_type)
|
||||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
if vision_config_class is not None:
|
||||||
elif vision_model_type == "clip_vision_model":
|
self.vision_config = vision_config_class(**vision_config)
|
||||||
self.vision_config = CLIPVisionConfig(**vision_config)
|
|
||||||
else:
|
else:
|
||||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
||||||
|
if hasattr(self.vision_config, "vision_config"):
|
||||||
|
self.vision_config = self.vision_config.vision_config
|
||||||
|
|
||||||
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
||||||
|
|
||||||
|
|||||||
@@ -1070,6 +1070,7 @@ MODELS_NOT_IN_README = [
|
|||||||
"VisionTextDualEncoder",
|
"VisionTextDualEncoder",
|
||||||
"CLIPVisionModel",
|
"CLIPVisionModel",
|
||||||
"SiglipVisionModel",
|
"SiglipVisionModel",
|
||||||
|
"ChineseCLIPVisionModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Template for new entries to add in the main README when we have missing models.
|
# Template for new entries to add in the main README when we have missing models.
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ MODEL_NAMES_WITH_SAME_CONFIG = {
|
|||||||
"XLS-R": "Wav2Vec2",
|
"XLS-R": "Wav2Vec2",
|
||||||
"XLSR-Wav2Vec2": "Wav2Vec2",
|
"XLSR-Wav2Vec2": "Wav2Vec2",
|
||||||
}
|
}
|
||||||
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel"]
|
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel"]
|
||||||
|
|
||||||
|
|
||||||
def get_model_table_from_auto_modules() -> str:
|
def get_model_table_from_auto_modules() -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user