Deal with nested configs better in base class (#25237)

* Deal better with nested configs

* Fixes

* More fixes

* Fix last test

* Clean up existing configs

* Remove hack in MPT Config

* Update src/transformers/configuration_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Fix setting a nested config via dict in the kwargs

* Adapt common test

* Add test for nested config load with dict

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2023-08-04 14:56:09 +02:00
committed by GitHub
parent aeb5a08abd
commit 29f04002e6
40 changed files with 62 additions and 566 deletions

View File

@@ -762,6 +762,10 @@ class PretrainedConfig(PushToHubMixin):
to_remove = [] to_remove = []
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(config, key): if hasattr(config, key):
current_attr = getattr(config, key)
# Allow a custom sub-config to be passed as a dict kwarg for models that have nested configs.
if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
value = current_attr.__class__(**value)
setattr(config, key, value) setattr(config, key, value)
if key != "torch_dtype": if key != "torch_dtype":
to_remove.append(key) to_remove.append(key)
@@ -823,6 +827,18 @@ class PretrainedConfig(PushToHubMixin):
# only serialize values that differ from the default config # only serialize values that differ from the default config
for key, value in config_dict.items(): for key, value in config_dict.items():
if ( if (
isinstance(getattr(self, key, None), PretrainedConfig)
and key in class_config_dict
and isinstance(class_config_dict[key], dict)
):
# For nested configs we need to clean the diff recursively
diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
if "model_type" in value:
# Needs to be set even if it's not in the diff
diff["model_type"] = value["model_type"]
if len(diff) > 0:
serializable_config_dict[key] = diff
elif (
key not in default_config_dict key not in default_config_dict
or key == "transformers_version" or key == "transformers_version"
or value != default_config_dict[key] or value != default_config_dict[key]
@@ -859,6 +875,14 @@ class PretrainedConfig(PushToHubMixin):
# Transformers version when serializing the model # Transformers version when serializing the model
output["transformers_version"] = __version__ output["transformers_version"] = __version__
for key, value in output.items():
# Deal with nested configs like CLIP
if isinstance(value, PretrainedConfig):
value = value.to_dict()
del value["transformers_version"]
output[key] = value
if hasattr(self, "quantization_config"): if hasattr(self, "quantization_config"):
output["quantization_config"] = ( output["quantization_config"] = (
self.quantization_config.to_dict() self.quantization_config.to_dict()
@@ -1020,6 +1044,24 @@ def get_configuration_file(configuration_files: List[str]) -> str:
return configuration_file return configuration_file
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Helper function that recursively builds the subset of `dict_a` worth keeping when compared against `dict_b` and
    against the defaults of `config_obj`'s class.

    An entry of `dict_a` is dropped only when it is equal to the entry of `dict_b` under the same key *and* equal to
    the value found under that key in `config_obj.__class__().to_dict()`. When `config_obj` is `None` there are no
    class defaults to compare with, so every entry of `dict_a` is kept.

    If the attribute of `config_obj` named after a key is itself a `PretrainedConfig` and `dict_b` holds a dictionary
    under that key, the diff is computed recursively for that entry; empty nested diffs are left out of the result.

    Args:
        dict_a (`dict`): Dictionary whose values end up in the diff.
        dict_b (`dict`): Dictionary to compare `dict_a` against.
        config_obj (*optional*): Object used to look up nested configs (through `getattr`) and whose class is
            instantiated without arguments to obtain the default values.

    Returns:
        `dict`: The (possibly nested) dictionary of retained entries.
    """
    if config_obj is None:
        class_defaults = {}
    else:
        class_defaults = config_obj.__class__().to_dict()

    result = {}
    for name, entry in dict_a.items():
        sub_config = getattr(config_obj, str(name), None)
        if isinstance(sub_config, PretrainedConfig) and name in dict_b and isinstance(dict_b[name], dict):
            # Nested config: recurse and only record the key when something actually differs.
            nested = recursive_diff_dict(entry, dict_b[name], config_obj=sub_config)
            if nested:
                result[name] = nested
            continue

        # Keep the entry as soon as it is missing from, or different to, either reference.
        if name not in dict_b or entry != dict_b[name]:
            result[name] = entry
        elif name not in class_defaults or entry != class_defaults[name]:
            result[name] = entry
    return result
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub) PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
if PretrainedConfig.push_to_hub.__doc__ is not None: if PretrainedConfig.push_to_hub.__doc__ is not None:
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format( PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" ALIGN model configuration""" """ ALIGN model configuration"""
import copy
import os import os
from typing import TYPE_CHECKING, List, Union from typing import TYPE_CHECKING, List, Union
@@ -344,7 +343,6 @@ class AlignConfig(PretrainedConfig):
```""" ```"""
model_type = "align" model_type = "align"
is_composition = True
def __init__( def __init__(
self, self,
@@ -383,16 +381,3 @@ class AlignConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" AltCLIP model configuration""" """ AltCLIP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -291,7 +290,6 @@ class AltCLIPConfig(PretrainedConfig):
```""" ```"""
model_type = "altclip" model_type = "altclip"
is_composition = True
def __init__( def __init__(
self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs self, text_config=None, vision_config=None, projection_dim=768, logit_scale_init_value=2.6592, **kwargs
@@ -392,16 +390,3 @@ class AltCLIPConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" BARK model configuration""" """ BARK model configuration"""
import copy
import os import os
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
@@ -271,7 +270,6 @@ class BarkConfig(PretrainedConfig):
""" """
model_type = "bark" model_type = "bark"
is_composition = True
def __init__( def __init__(
self, self,
@@ -329,20 +327,3 @@ class BarkConfig(PretrainedConfig):
codec_config=codec_config.to_dict(), codec_config=codec_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["semantic_config"] = self.semantic_config.to_dict()
output["coarse_acoustics_config"] = self.coarse_acoustics_config.to_dict()
output["fine_acoustics_config"] = self.fine_acoustics_config.to_dict()
output["codec_config"] = self.codec_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Blip model configuration""" """ Blip model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -325,7 +324,6 @@ class BlipConfig(PretrainedConfig):
```""" ```"""
model_type = "blip" model_type = "blip"
is_composition = True
def __init__( def __init__(
self, self,
@@ -368,16 +366,3 @@ class BlipConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" BLIP-2 model configuration""" """ BLIP-2 model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -302,7 +301,6 @@ class Blip2Config(PretrainedConfig):
```""" ```"""
model_type = "blip-2" model_type = "blip-2"
is_composition = True
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -355,17 +353,3 @@ class Blip2Config(PretrainedConfig):
text_config=text_config.to_dict(), text_config=text_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["qformer_config"] = self.qformer_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" BridgeTower model configuration""" """ BridgeTower model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -349,16 +348,3 @@ class BridgeTowerConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Chinese-CLIP model configuration""" """ Chinese-CLIP model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
@@ -314,7 +313,6 @@ class ChineseCLIPConfig(PretrainedConfig):
```""" ```"""
model_type = "chinese_clip" model_type = "chinese_clip"
is_composition = True
def __init__( def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
@@ -417,19 +415,6 @@ class ChineseCLIPConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class ChineseCLIPOnnxConfig(OnnxConfig): class ChineseCLIPOnnxConfig(OnnxConfig):
@property @property

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" CLAP model configuration""" """ CLAP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -382,7 +381,6 @@ class ClapConfig(PretrainedConfig):
```""" ```"""
model_type = "clap" model_type = "clap"
is_composition = True
def __init__( def __init__(
self, self,
@@ -431,16 +429,3 @@ class ClapConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["audio_config"] = self.audio_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" CLIP model configuration""" """ CLIP model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
@@ -298,7 +297,6 @@ class CLIPConfig(PretrainedConfig):
```""" ```"""
model_type = "clip" model_type = "clip"
is_composition = True
def __init__( def __init__(
self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
@@ -400,19 +398,6 @@ class CLIPConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class CLIPOnnxConfig(OnnxConfig): class CLIPOnnxConfig(OnnxConfig):
@property @property

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" CLIPSeg model configuration""" """ CLIPSeg model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -302,7 +301,6 @@ class CLIPSegConfig(PretrainedConfig):
```""" ```"""
model_type = "clipseg" model_type = "clipseg"
is_composition = True
def __init__( def __init__(
self, self,
@@ -424,16 +422,3 @@ class CLIPSegConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Conditional DETR model configuration""" """ Conditional DETR model configuration"""
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping from typing import Mapping
@@ -238,19 +237,6 @@ class ConditionalDetrConfig(PretrainedConfig):
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if self.backbone_config is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class ConditionalDetrOnnxConfig(OnnxConfig): class ConditionalDetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11") torch_onnx_minimum_version = version.parse("1.11")

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Deformable DETR model configuration""" """ Deformable DETR model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -261,16 +260,3 @@ class DeformableDetrConfig(PretrainedConfig):
@property @property
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if self.backbone_config is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" DETA model configuration""" """ DETA model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -230,13 +229,3 @@ class DetaConfig(PretrainedConfig):
@property @property
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
""" DETR model configuration""" """ DETR model configuration"""
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, Mapping from typing import Mapping
from packaging import version from packaging import version
@@ -248,17 +247,6 @@ class DetrConfig(PretrainedConfig):
""" """
return cls(backbone_config=backbone_config, **kwargs) return cls(backbone_config=backbone_config, **kwargs)
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if output["backbone_config"] is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class DetrOnnxConfig(OnnxConfig): class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11") torch_onnx_minimum_version = version.parse("1.11")

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -104,16 +103,3 @@ class EncoderDecoderConfig(PretrainedConfig):
decoder_config.add_cross_attention = True decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" FLAVA model configurations""" """ FLAVA model configurations"""
import copy
import os import os
from typing import Any, Dict, Union from typing import Any, Dict, Union
@@ -536,7 +535,6 @@ class FlavaConfig(PretrainedConfig):
""" """
model_type = "flava" model_type = "flava"
is_composition = True
def __init__( def __init__(
self, self,
@@ -764,18 +762,3 @@ class FlavaConfig(PretrainedConfig):
image_codebook_config=image_codebook_config.to_dict(), image_codebook_config=image_codebook_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["image_config"] = self.image_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["multimodal_config"] = self.multimodal_config.to_dict()
output["image_codebook_config"] = self.image_codebook_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -15,8 +15,6 @@
""" FSMT configuration""" """ FSMT configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -216,15 +214,3 @@ class FSMTConfig(PretrainedConfig):
early_stopping=early_stopping, early_stopping=early_stopping,
**common_kwargs, **common_kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import os import os
from typing import Union from typing import Union
@@ -239,13 +238,3 @@ class GitConfig(PretrainedConfig):
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" GroupViT model configuration""" """ GroupViT model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
@@ -296,7 +295,6 @@ class GroupViTConfig(PretrainedConfig):
""" """
model_type = "groupvit" model_type = "groupvit"
is_composition = True
def __init__( def __init__(
self, self,
@@ -407,19 +405,6 @@ class GroupViTConfig(PretrainedConfig):
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class GroupViTOnnxConfig(OnnxConfig): class GroupViTOnnxConfig(OnnxConfig):
@property @property

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" InstructBLIP model configuration""" """ InstructBLIP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -305,7 +304,6 @@ class InstructBlipConfig(PretrainedConfig):
```""" ```"""
model_type = "instructblip" model_type = "instructblip"
is_composition = True
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs): def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -358,17 +356,3 @@ class InstructBlipConfig(PretrainedConfig):
text_config=text_config.to_dict(), text_config=text_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["qformer_config"] = self.qformer_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Jukebox configuration""" """ Jukebox configuration"""
import copy
import os import os
from typing import List, Union from typing import List, Union
@@ -369,18 +368,6 @@ class JukeboxPriorConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs) return cls.from_dict(config_dict, **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder_config"] = self.encoder_config.to_dict() if self.encoder_config is not None else None
output["model_type"] = self.__class__.model_type
return output
class JukeboxVQVAEConfig(PretrainedConfig): class JukeboxVQVAEConfig(PretrainedConfig):
""" """
@@ -561,7 +548,6 @@ class JukeboxConfig(PretrainedConfig):
""" """
model_type = "jukebox" model_type = "jukebox"
is_composition = True
def __init__( def __init__(
self, self,
@@ -620,18 +606,3 @@ class JukeboxConfig(PretrainedConfig):
""" """
prior_config_list = [config.to_dict() for config in prior_configs] prior_config_list = [config.to_dict() for config in prior_configs]
return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs) return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
for i, config in enumerate(output.pop("prior_configs")):
output[f"prior_{i}"] = config.to_dict()
output["vqvae_config"] = self.vqvae_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Mask2Former model configuration""" """ Mask2Former model configuration"""
import copy
from typing import Dict, List, Optional from typing import Dict, List, Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
@@ -230,15 +229,3 @@ class Mask2FormerConfig(PretrainedConfig):
backbone_config=backbone_config, backbone_config=backbone_config,
**kwargs, **kwargs,
) )
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" MaskFormer model configuration""" """ MaskFormer model configuration"""
import copy
from typing import Dict, Optional from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
@@ -200,16 +199,3 @@ class MaskFormerConfig(PretrainedConfig):
decoder_config=decoder_config, decoder_config=decoder_config,
**kwargs, **kwargs,
) )
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["decoder_config"] = self.decoder_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Mpt configuration""" """ Mpt configuration"""
import copy
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
@@ -197,7 +196,6 @@ class MptConfig(PretrainedConfig):
"hidden_size": "d_model", "hidden_size": "d_model",
"num_hidden_layers": "n_layers", "num_hidden_layers": "n_layers",
} }
is_composition = True
def __init__( def __init__(
self, self,
@@ -222,6 +220,11 @@ class MptConfig(PretrainedConfig):
initializer_range=0.02, initializer_range=0.02,
**kwargs, **kwargs,
): ):
if attn_config is None:
self.attn_config = MptAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = MptAttentionConfig(**attn_config)
else:
self.attn_config = attn_config self.attn_config = attn_config
self.d_model = d_model self.d_model = d_model
self.n_heads = n_heads self.n_heads = n_heads
@@ -242,35 +245,3 @@ class MptConfig(PretrainedConfig):
self.use_cache = use_cache self.use_cache = use_cache
self.initializer_range = initializer_range self.initializer_range = initializer_range
super().__init__(**kwargs) super().__init__(**kwargs)
@property
def attn_config(self):
return self._attn_config
@attn_config.setter
def attn_config(self, attn_config):
if attn_config is None:
self._attn_config = MptAttentionConfig()
elif isinstance(attn_config, dict):
self._attn_config = MptAttentionConfig(**attn_config)
elif isinstance(attn_config, MptAttentionConfig):
self._attn_config = attn_config
else:
raise ValueError(
f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["attn_config"] = (
self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
)
del output["_attn_config"]
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" MusicGen model configuration""" """ MusicGen model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -227,17 +226,3 @@ class MusicgenConfig(PretrainedConfig):
decoder=decoder_config.to_dict(), decoder=decoder_config.to_dict(),
**kwargs, **kwargs,
) )
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_encoder"] = self.text_encoder.to_dict()
output["audio_encoder"] = self.audio_encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""OneFormer model configuration""" """OneFormer model configuration"""
import copy
from typing import Dict, Optional from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
@@ -250,13 +249,3 @@ class OneFormerConfig(PretrainedConfig):
self.num_hidden_layers = decoder_layers self.num_hidden_layers = decoder_layers
super().__init__(**kwargs) super().__init__(**kwargs)
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" OWL-ViT model configuration""" """ OWL-ViT model configuration"""
import copy
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
@@ -274,7 +273,6 @@ class OwlViTConfig(PretrainedConfig):
""" """
model_type = "owlvit" model_type = "owlvit"
is_composition = True
def __init__( def __init__(
self, self,
@@ -332,19 +330,6 @@ class OwlViTConfig(PretrainedConfig):
return cls.from_dict(config_dict, **kwargs) return cls.from_dict(config_dict, **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class OwlViTOnnxConfig(OnnxConfig): class OwlViTOnnxConfig(OnnxConfig):
@property @property

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Pix2Struct model configuration""" """ Pix2Struct model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -338,7 +337,6 @@ class Pix2StructConfig(PretrainedConfig):
```""" ```"""
model_type = "pix2struct" model_type = "pix2struct"
is_composition = True
def __init__( def __init__(
self, self,
@@ -389,16 +387,3 @@ class Pix2StructConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" RAG model configuration""" """ RAG model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings from ...utils import add_start_docstrings
@@ -179,16 +178,3 @@ class RagConfig(PretrainedConfig):
[`EncoderDecoderConfig`]: An instance of a configuration object [`EncoderDecoderConfig`]: An instance of a configuration object
""" """
return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs) return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["question_encoder"] = self.question_encoder.to_dict()
output["generator"] = self.generator.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" SAM model configuration""" """ SAM model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -286,7 +285,6 @@ class SamConfig(PretrainedConfig):
```""" ```"""
model_type = "sam" model_type = "sam"
is_composition = True
def __init__( def __init__(
self, self,
@@ -312,17 +310,3 @@ class SamConfig(PretrainedConfig):
self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config) self.prompt_encoder_config = SamPromptEncoderConfig(**prompt_encoder_config)
self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config) self.mask_decoder_config = SamMaskDecoderConfig(**mask_decoder_config)
self.initializer_range = initializer_range self.initializer_range = initializer_range
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["prompt_encoder_config"] = self.prompt_encoder_config.to_dict()
output["mask_decoder_config"] = self.mask_decoder_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -106,16 +105,3 @@ class SpeechEncoderDecoderConfig(PretrainedConfig):
decoder_config.add_cross_attention = True decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -13,9 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Table Transformer model configuration""" """ Table Transformer model configuration"""
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, Mapping from typing import Mapping
from packaging import version from packaging import version
@@ -237,17 +236,6 @@ class TableTransformerConfig(PretrainedConfig):
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if output["backbone_config"] is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
# Copied from transformers.models.detr.configuration_detr.DetrOnnxConfig # Copied from transformers.models.detr.configuration_detr.DetrOnnxConfig
class TableTransformerOnnxConfig(OnnxConfig): class TableTransformerOnnxConfig(OnnxConfig):

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" UperNet model configuration""" """ UperNet model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -108,13 +107,3 @@ class UperNetConfig(PretrainedConfig):
self.auxiliary_num_convs = auxiliary_num_convs self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input self.auxiliary_concat_input = auxiliary_concat_input
self.loss_ignore_index = loss_ignore_index self.loss_ignore_index = loss_ignore_index
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
from packaging import version from packaging import version
@@ -114,19 +113,6 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig): class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11") torch_onnx_minimum_version = version.parse("1.11")

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" VisionTextDualEncoder model configuration""" """ VisionTextDualEncoder model configuration"""
import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -113,16 +112,3 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
""" """
return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs) return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
""" ViT Hybrid model configuration""" """ ViT Hybrid model configuration"""
import copy
from typing import Dict
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -146,13 +144,3 @@ class ViTHybridConfig(PretrainedConfig):
self.patch_size = patch_size self.patch_size = patch_size
self.num_channels = num_channels self.num_channels = num_channels
self.qkv_bias = qkv_bias self.qkv_bias = qkv_bias
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" X-CLIP model configuration""" """ X-CLIP model configuration"""
import copy
import os import os
from typing import Union from typing import Union
@@ -299,7 +298,6 @@ class XCLIPConfig(PretrainedConfig):
""" """
model_type = "xclip" model_type = "xclip"
is_composition = True
def __init__( def __init__(
self, self,
@@ -417,16 +415,3 @@ class XCLIPConfig(PretrainedConfig):
""" """
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["text_config"] = self.text_config.to_dict()
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output

View File

@@ -118,7 +118,9 @@ class ConfigTester(object):
def check_config_can_be_init_without_params(self): def check_config_can_be_init_without_params(self):
if self.config_class.is_composition: if self.config_class.is_composition:
return with self.parent.assertRaises(ValueError):
config = self.config_class()
else:
config = self.config_class() config = self.config_class()
self.parent.assertIsNotNone(config) self.parent.assertIsNotNone(config)

View File

@@ -210,6 +210,13 @@ class ConfigTestUtils(unittest.TestCase):
f" {', '.join(keys_with_defaults)}." f" {', '.join(keys_with_defaults)}."
) )
def test_nested_config_load_from_dict(self):
config = AutoConfig.from_pretrained(
"hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
)
self.assertNotIsInstance(config.text_config, dict)
self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")
def test_from_pretrained_subfolder(self): def test_from_pretrained_subfolder(self):
with self.assertRaises(OSError): with self.assertRaises(OSError):
# config is in subfolder, the following should not work without specifying the subfolder # config is in subfolder, the following should not work without specifying the subfolder