Fix loading CLIPVisionConfig and CLIPTextConfig (#16198)
* override from_pretrained * add tests * remove docstrings * fix typo * Trigger CI
This commit is contained in:
@@ -15,6 +15,8 @@
|
|||||||
""" CLIP model configuration"""
|
""" CLIP model configuration"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -118,6 +120,23 @@ class CLIPTextConfig(PretrainedConfig):
|
|||||||
self.initializer_factor = initializer_factor
|
self.initializer_factor = initializer_factor
|
||||||
self.attention_dropout = attention_dropout
|
self.attention_dropout = attention_dropout
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||||
|
|
||||||
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
# get the text config dict if we are loading from CLIPConfig
|
||||||
|
if config_dict.get("model_type") == "clip":
|
||||||
|
config_dict = config_dict["text_config"]
|
||||||
|
|
||||||
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||||
|
logger.warning(
|
||||||
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||||
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionConfig(PretrainedConfig):
|
class CLIPVisionConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
@@ -205,6 +224,23 @@ class CLIPVisionConfig(PretrainedConfig):
|
|||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.hidden_act = hidden_act
|
self.hidden_act = hidden_act
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||||
|
|
||||||
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
# get the vision config dict if we are loading from CLIPConfig
|
||||||
|
if config_dict.get("model_type") == "clip":
|
||||||
|
config_dict = config_dict["vision_config"]
|
||||||
|
|
||||||
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||||
|
logger.warning(
|
||||||
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||||
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CLIPConfig(PretrainedConfig):
|
class CLIPConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -588,6 +588,21 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
def test_load_vision_text_config(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# Save CLIPConfig and check if we can load CLIPVisionConfig from it
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
config.save_pretrained(tmp_dir_name)
|
||||||
|
vision_config = CLIPVisionConfig.from_pretrained(tmp_dir_name)
|
||||||
|
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||||
|
|
||||||
|
# Save CLIPConfig and check if we can load CLIPTextConfig from it
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
config.save_pretrained(tmp_dir_name)
|
||||||
|
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
|
||||||
|
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||||
|
|
||||||
# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
|
# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
|
||||||
@is_pt_tf_cross_test
|
@is_pt_tf_cross_test
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user