Use name transformers_config in Keras serialization
Be explicit that this is the config for the transformers package (these layers may coexist with other custom stuff in a Keras model, the Keras container itself is called `config`, and `config["config"]` is not great). Add explicit error handling for initializer calls that have neither the `config` nor the `transformers_config` argument, or that have both.
This commit is contained in:
@@ -54,10 +54,20 @@ def keras_serializable(cls):
|
|||||||
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
||||||
|
|
||||||
@functools.wraps(initializer)
|
@functools.wraps(initializer)
|
||||||
def wrapped_init(self, config, *args, **kwargs):
|
def wrapped_init(self, *args, **kwargs):
|
||||||
if isinstance(config, dict):
|
transformers_config = kwargs.pop("transformers_config", None)
|
||||||
config = config_class.from_dict(config)
|
config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
|
||||||
|
if config is not None and transformers_config is not None:
|
||||||
|
raise ValueError("Must pass either `config` or `transformers_config`, not both")
|
||||||
|
elif config is not None:
|
||||||
|
# normal layer construction, call with unchanged args (config is already in there)
|
||||||
|
initializer(self, *args, **kwargs)
|
||||||
|
elif transformers_config is not None:
|
||||||
|
# Keras deserialization, convert dict to config
|
||||||
|
config = config_class.from_dict(transformers_config)
|
||||||
initializer(self, config, *args, **kwargs)
|
initializer(self, config, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
|
||||||
self._transformers_config = config
|
self._transformers_config = config
|
||||||
|
|
||||||
cls.__init__ = wrapped_init
|
cls.__init__ = wrapped_init
|
||||||
@@ -68,7 +78,7 @@ def keras_serializable(cls):
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
cfg = super(cls, self).get_config()
|
cfg = super(cls, self).get_config()
|
||||||
cfg["config"] = self._transformers_config.to_dict()
|
cfg["transformers_config"] = self._transformers_config.to_dict()
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
cls.get_config = get_config
|
cls.get_config = get_config
|
||||||
|
|||||||
Reference in New Issue
Block a user