Merge pull request #3103 from gthb/keras-serialization
Support keras JSON/HDF5 serialization of main layers
This commit is contained in:
@@ -23,7 +23,7 @@ import tensorflow as tf
|
|||||||
from .configuration_albert import AlbertConfig
|
from .configuration_albert import AlbertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -478,7 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = AlbertConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -471,7 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
|||||||
return seq_relationship_score
|
return seq_relationship_score
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFBertMainLayer(tf.keras.layers.Layer):
|
class TFBertMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = BertConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -164,7 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = CTRLConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .modeling_tf_utils import (
|
|||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
get_initializer,
|
get_initializer,
|
||||||
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -196,7 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
return outputs # x, present, (attentions)
|
return outputs # x, present, (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = GPT2Config
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import tensorflow as tf
|
|||||||
from .configuration_transfo_xl import TransfoXLConfig
|
from .configuration_transfo_xl import TransfoXLConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -378,7 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = TransfoXLConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
|
|||||||
@@ -14,8 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""TF general model utils."""
|
"""TF general model utils."""
|
||||||
|
import functools
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -47,6 +46,64 @@ class TFModelUtilsMixin:
|
|||||||
return self.count_params()
|
return self.count_params()
|
||||||
|
|
||||||
|
|
||||||
|
def keras_serializable(cls):
|
||||||
|
"""
|
||||||
|
Decorate a Keras Layer class to support Keras serialization.
|
||||||
|
|
||||||
|
This is done by:
|
||||||
|
1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
|
||||||
|
serialization time
|
||||||
|
2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
|
||||||
|
convert it to a config object for the actual layer initializer
|
||||||
|
3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
|
||||||
|
not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`
|
||||||
|
|
||||||
|
:param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a
|
||||||
|
`TF*MainLayer` class in this project)
|
||||||
|
:return: the same class object, with modifications for Keras deserialization.
|
||||||
|
"""
|
||||||
|
initializer = cls.__init__
|
||||||
|
|
||||||
|
config_class = getattr(cls, "config_class", None)
|
||||||
|
if config_class is None:
|
||||||
|
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
||||||
|
|
||||||
|
@functools.wraps(initializer)
|
||||||
|
def wrapped_init(self, *args, **kwargs):
|
||||||
|
transformers_config = kwargs.pop("transformers_config", None)
|
||||||
|
config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
|
||||||
|
if config is not None and transformers_config is not None:
|
||||||
|
raise ValueError("Must pass either `config` or `transformers_config`, not both")
|
||||||
|
elif config is not None:
|
||||||
|
# normal layer construction, call with unchanged args (config is already in there)
|
||||||
|
initializer(self, *args, **kwargs)
|
||||||
|
elif transformers_config is not None:
|
||||||
|
# Keras deserialization, convert dict to config
|
||||||
|
config = config_class.from_dict(transformers_config)
|
||||||
|
initializer(self, config, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
|
||||||
|
self._transformers_config = config
|
||||||
|
|
||||||
|
cls.__init__ = wrapped_init
|
||||||
|
|
||||||
|
if not hasattr(cls, "get_config"):
|
||||||
|
raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
|
||||||
|
if hasattr(cls.get_config, "_is_default"):
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
cfg = super(cls, self).get_config()
|
||||||
|
cfg["transformers_config"] = self._transformers_config.to_dict()
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
cls.get_config = get_config
|
||||||
|
|
||||||
|
cls._keras_serializable = True
|
||||||
|
if hasattr(tf.keras.utils, "register_keras_serializable"):
|
||||||
|
cls = tf.keras.utils.register_keras_serializable()(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||||
r""" Base class for all TF models.
|
r""" Base class for all TF models.
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,14 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_xlnet import XLNetConfig
|
from .configuration_xlnet import XLNetConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
|
from .modeling_tf_utils import (
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFSequenceSummary,
|
||||||
|
TFSharedEmbeddings,
|
||||||
|
get_initializer,
|
||||||
|
keras_serializable,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -342,7 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFXLNetMainLayer(tf.keras.layers.Layer):
|
class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = XLNetConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
|
|
||||||
@@ -89,9 +90,45 @@ class TFModelTesterMixin:
|
|||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
after_outputs = model(inputs_dict)
|
after_outputs = model(inputs_dict)
|
||||||
|
|
||||||
|
self.assert_outputs_same(after_outputs, outputs)
|
||||||
|
|
||||||
|
def test_keras_save_load(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
tf_main_layer_classes = set(
|
||||||
|
module_member
|
||||||
|
for model_class in self.all_model_classes
|
||||||
|
for module in (import_module(model_class.__module__),)
|
||||||
|
for module_member_name in dir(module)
|
||||||
|
if module_member_name.endswith("MainLayer")
|
||||||
|
for module_member in (getattr(module, module_member_name),)
|
||||||
|
if isinstance(module_member, type)
|
||||||
|
and tf.keras.layers.Layer in module_member.__bases__
|
||||||
|
and getattr(module_member, "_keras_serializable", False)
|
||||||
|
)
|
||||||
|
for main_layer_class in tf_main_layer_classes:
|
||||||
|
main_layer = main_layer_class(config)
|
||||||
|
symbolic_inputs = {
|
||||||
|
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||||
|
}
|
||||||
|
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
|
||||||
|
outputs = model(inputs_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
filepath = os.path.join(tmpdirname, "keras_model.h5")
|
||||||
|
model.save(filepath)
|
||||||
|
model = tf.keras.models.load_model(
|
||||||
|
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
|
||||||
|
)
|
||||||
|
assert isinstance(model, tf.keras.Model)
|
||||||
|
after_outputs = model(inputs_dict)
|
||||||
|
self.assert_outputs_same(after_outputs, outputs)
|
||||||
|
|
||||||
|
def assert_outputs_same(self, after_outputs, outputs):
|
||||||
# Make sure we don't have nans
|
# Make sure we don't have nans
|
||||||
out_1 = after_outputs[0].numpy()
|
out_1 = after_outputs[0].numpy()
|
||||||
out_2 = outputs[0].numpy()
|
out_2 = outputs[0].numpy()
|
||||||
|
self.assertEqual(out_1.shape, out_2.shape)
|
||||||
out_1 = out_1[~np.isnan(out_1)]
|
out_1 = out_1[~np.isnan(out_1)]
|
||||||
out_2 = out_2[~np.isnan(out_2)]
|
out_2 = out_2[~np.isnan(out_2)]
|
||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
|
|||||||
Reference in New Issue
Block a user