Merge pull request #3103 from gthb/keras-serialization
Support keras JSON/HDF5 serialization of main layers
This commit is contained in:
@@ -14,8 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF general model utils."""
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -47,6 +46,64 @@ class TFModelUtilsMixin:
|
||||
return self.count_params()
|
||||
|
||||
|
||||
def keras_serializable(cls):
|
||||
"""
|
||||
Decorate a Keras Layer class to support Keras serialization.
|
||||
|
||||
This is done by:
|
||||
1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
|
||||
serialization time
|
||||
2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
|
||||
convert it to a config object for the actual layer initializer
|
||||
3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
|
||||
not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`
|
||||
|
||||
:param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a
|
||||
`TF*MainLayer` class in this project)
|
||||
:return: the same class object, with modifications for Keras deserialization.
|
||||
"""
|
||||
initializer = cls.__init__
|
||||
|
||||
config_class = getattr(cls, "config_class", None)
|
||||
if config_class is None:
|
||||
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
||||
|
||||
@functools.wraps(initializer)
|
||||
def wrapped_init(self, *args, **kwargs):
|
||||
transformers_config = kwargs.pop("transformers_config", None)
|
||||
config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
|
||||
if config is not None and transformers_config is not None:
|
||||
raise ValueError("Must pass either `config` or `transformers_config`, not both")
|
||||
elif config is not None:
|
||||
# normal layer construction, call with unchanged args (config is already in there)
|
||||
initializer(self, *args, **kwargs)
|
||||
elif transformers_config is not None:
|
||||
# Keras deserialization, convert dict to config
|
||||
config = config_class.from_dict(transformers_config)
|
||||
initializer(self, config, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
|
||||
self._transformers_config = config
|
||||
|
||||
cls.__init__ = wrapped_init
|
||||
|
||||
if not hasattr(cls, "get_config"):
|
||||
raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
|
||||
if hasattr(cls.get_config, "_is_default"):
|
||||
|
||||
def get_config(self):
|
||||
cfg = super(cls, self).get_config()
|
||||
cfg["transformers_config"] = self._transformers_config.to_dict()
|
||||
return cfg
|
||||
|
||||
cls.get_config = get_config
|
||||
|
||||
cls._keras_serializable = True
|
||||
if hasattr(tf.keras.utils, "register_keras_serializable"):
|
||||
cls = tf.keras.utils.register_keras_serializable()(cls)
|
||||
return cls
|
||||
|
||||
|
||||
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
r""" Base class for all TF models.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user