Merge pull request #3103 from gthb/keras-serialization

Support keras JSON/HDF5 serialization of main layers
This commit is contained in:
Thomas Wolf
2020-03-06 12:59:13 +01:00
committed by GitHub
8 changed files with 134 additions and 14 deletions

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_albert import AlbertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -478,7 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
return hidden_states
@keras_serializable
class TFAlbertMainLayer(tf.keras.layers.Layer):
config_class = AlbertConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -471,7 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
return seq_relationship_score
@keras_serializable
class TFBertMainLayer(tf.keras.layers.Layer):
config_class = BertConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -164,7 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return outputs
@keras_serializable
class TFCTRLMainLayer(tf.keras.layers.Layer):
config_class = CTRLConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states

View File

@@ -29,6 +29,7 @@ from .modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
keras_serializable,
shape_list,
)
@@ -196,7 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, present, (attentions)
@keras_serializable
class TFGPT2MainLayer(tf.keras.layers.Layer):
config_class = GPT2Config
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states

View File

@@ -24,7 +24,7 @@ import tensorflow as tf
from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -378,7 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
return embed
@keras_serializable
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
config_class = TransfoXLConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions

View File

@@ -14,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import logging
import os
@@ -47,6 +46,64 @@ class TFModelUtilsMixin:
return self.count_params()
def keras_serializable(cls):
"""
Decorate a Keras Layer class to support Keras serialization.
This is done by:
1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
serialization time
2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
convert it to a config object for the actual layer initializer
3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`
:param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a
`TF*MainLayer` class in this project)
:return: the same class object, with modifications for Keras deserialization.
"""
initializer = cls.__init__
config_class = getattr(cls, "config_class", None)
if config_class is None:
raise AttributeError("Must set `config_class` to use @keras_serializable")
@functools.wraps(initializer)
def wrapped_init(self, *args, **kwargs):
transformers_config = kwargs.pop("transformers_config", None)
config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
if config is not None and transformers_config is not None:
raise ValueError("Must pass either `config` or `transformers_config`, not both")
elif config is not None:
# normal layer construction, call with unchanged args (config is already in there)
initializer(self, *args, **kwargs)
elif transformers_config is not None:
# Keras deserialization, convert dict to config
config = config_class.from_dict(transformers_config)
initializer(self, config, *args, **kwargs)
else:
raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
self._transformers_config = config
cls.__init__ = wrapped_init
if not hasattr(cls, "get_config"):
raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
if hasattr(cls.get_config, "_is_default"):
def get_config(self):
cfg = super(cls, self).get_config()
cfg["transformers_config"] = self._transformers_config.to_dict()
return cfg
cls.get_config = get_config
cls._keras_serializable = True
if hasattr(tf.keras.utils, "register_keras_serializable"):
cls = tf.keras.utils.register_keras_serializable()(cls)
return cls
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
r""" Base class for all TF models.

View File

@@ -24,7 +24,14 @@ import tensorflow as tf
from .configuration_xlnet import XLNetConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
keras_serializable,
shape_list,
)
logger = logging.getLogger(__name__)
@@ -342,7 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
return hidden_states
@keras_serializable
class TFXLNetMainLayer(tf.keras.layers.Layer):
config_class = XLNetConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions