From 4be01e5cbff5c449f6ef4c0cbb79ccd625dec156 Mon Sep 17 00:00:00 2001
From: Gunnlaugur Thor Briem
Date: Thu, 5 Mar 2020 11:08:45 +0000
Subject: [PATCH] Use name transformers_config in Keras serialization

Be explicit that this is config for the transformers package (as these
layers may coexist with other custom stuff in a Keras model, plus the
Keras container itself is called config, and config["config"] is not
great)

Add explicit error handling for initializer calls that have neither the
`config` nor the `transformers_config` argument, or have both.
---
 src/transformers/modeling_tf_utils.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 29049f7371..accaed00ae 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -54,10 +54,20 @@ def keras_serializable(cls):
         raise AttributeError("Must set `config_class` to use @keras_serializable")
 
     @functools.wraps(initializer)
-    def wrapped_init(self, config, *args, **kwargs):
-        if isinstance(config, dict):
-            config = config_class.from_dict(config)
-        initializer(self, config, *args, **kwargs)
+    def wrapped_init(self, *args, **kwargs):
+        transformers_config = kwargs.pop("transformers_config", None)
+        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
+        if config is not None and transformers_config is not None:
+            raise ValueError("Must pass either `config` or `transformers_config`, not both")
+        elif config is not None:
+            # normal layer construction, call with unchanged args (config is already in there)
+            initializer(self, *args, **kwargs)
+        elif transformers_config is not None:
+            # Keras deserialization, convert dict to config
+            config = config_class.from_dict(transformers_config)
+            initializer(self, config, *args, **kwargs)
+        else:
+            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
         self._transformers_config = config
 
     cls.__init__ = wrapped_init
@@ -68,7 +78,7 @@ def keras_serializable(cls):
 
     def get_config(self):
         cfg = super(cls, self).get_config()
-        cfg["config"] = self._transformers_config.to_dict()
+        cfg["transformers_config"] = self._transformers_config.to_dict()
         return cfg
 
     cls.get_config = get_config