diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 75f785ede0..78b0077d8e 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertSelfAttention -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -478,7 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): return hidden_states +@keras_serializable class TFAlbertMainLayer(tf.keras.layers.Layer): + config_class = AlbertConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 1904623581..11ae8da6b8 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -471,7 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer): return seq_relationship_score +@keras_serializable class TFBertMainLayer(tf.keras.layers.Layer): + config_class = BertConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 
8a049bbce9..2b355c20c5 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -164,7 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer): return outputs +@keras_serializable class TFCTRLMainLayer(tf.keras.layers.Layer): + config_class = CTRLConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 3b79d58949..9111f7a923 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -29,6 +29,7 @@ from .modeling_tf_utils import ( TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -196,7 +197,10 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, present, (attentions) +@keras_serializable class TFGPT2MainLayer(tf.keras.layers.Layer): + config_class = GPT2Config + def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 098a4c9143..53c4dc7bef 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_transfo_xl import TransfoXLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask -from .modeling_tf_utils 
def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:
    1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
       serialization time)
    2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer
    3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
       not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`

    :param cls: a tf.keras.layers.Layer subclass that accepts a `config` argument to its initializer (typically a
        `TF*MainLayer` class in this project)
    :return: the same class object, with modifications for Keras deserialization.
    :raises AttributeError: if `cls` does not define a (non-None) `config_class` attribute.
    :raises TypeError: if `cls` has no `get_config` attribute, i.e. it does not look like a Keras layer.
    """
    # Run every validation before touching the class, so that a rejected class is never left with a
    # replaced `__init__` but without the matching `get_config`.
    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")
    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")

    initializer = cls.__init__

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        # `transformers_config` is only meaningful to this wrapper; it must not reach the real initializer.
        transformers_config = kwargs.pop("transformers_config", None)
        # A config object may arrive positionally (first argument) or as the `config` keyword.
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
        if config is not None and transformers_config is not None:
            raise ValueError("Must pass either `config` or `transformers_config`, not both")
        elif config is not None:
            # normal layer construction, call with unchanged args (config is already in there)
            initializer(self, *args, **kwargs)
        elif transformers_config is not None:
            # Keras deserialization, convert dict to config
            config = config_class.from_dict(transformers_config)
            initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
        # Remember the config so that `get_config` can serialize it later.
        self._transformers_config = config

    cls.__init__ = wrapped_init

    # Only replace `get_config` when the current one carries the `_is_default` marker (presumably set by
    # Keras on its stock implementation -- confirm against the supported TF versions); a `get_config`
    # without the marker is left untouched.
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["transformers_config"] = self._transformers_config.to_dict()
            return cfg

        cls.get_config = get_config

    # Marker attribute that lets other code recognise decorated classes.
    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 87ebe16858..8797a22194 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -24,7 +24,14 @@ import tensorflow as tf from .configuration_xlnet import XLNetConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + keras_serializable, + shape_list, +) logger = logging.getLogger(__name__) @@ -342,7 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer): return hidden_states +@keras_serializable class TFXLNetMainLayer(tf.keras.layers.Layer): + config_class = XLNetConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.output_attentions = config.output_attentions diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index a6d2e8e32f..6887388d80 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -19,6 +19,7 @@ import os import random import tempfile import unittest +from importlib import import_module from transformers import is_tf_available, is_torch_available @@ -89,13 +90,49 @@ class TFModelTesterMixin: model = model_class.from_pretrained(tmpdirname) after_outputs = model(inputs_dict) - # Make sure we don't have nans - out_1 = after_outputs[0].numpy() - out_2 = outputs[0].numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + self.assert_outputs_same(after_outputs, outputs) + + def test_keras_save_load(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + tf_main_layer_classes = set( + module_member + for model_class in self.all_model_classes + 
for module in (import_module(model_class.__module__),) + for module_member_name in dir(module) + if module_member_name.endswith("MainLayer") + for module_member in (getattr(module, module_member_name),) + if isinstance(module_member, type) + and tf.keras.layers.Layer in module_member.__bases__ + and getattr(module_member, "_keras_serializable", False) + ) + for main_layer_class in tf_main_layer_classes: + main_layer = main_layer_class(config) + symbolic_inputs = { + name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() + } + model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs)) + outputs = model(inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, "keras_model.h5") + model.save(filepath) + model = tf.keras.models.load_model( + filepath, custom_objects={main_layer_class.__name__: main_layer_class} + ) + assert isinstance(model, tf.keras.Model) + after_outputs = model(inputs_dict) + self.assert_outputs_same(after_outputs, outputs) + + def assert_outputs_same(self, after_outputs, outputs): + # Make sure we don't have nans + out_1 = after_outputs[0].numpy() + out_2 = outputs[0].numpy() + self.assertEqual(out_1.shape, out_2.shape) + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) def test_pt_tf_model_equivalence(self): if not is_torch_available():