From 4c91a3af941627297c4b81fb48be1c9cb5ae7ee0 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Thu, 5 Mar 2020 11:48:10 +0000 Subject: [PATCH] Document keras_serializable decorator --- src/transformers/modeling_tf_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index accaed00ae..4fc00ff01a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,21 @@ class TFModelUtilsMixin: def keras_serializable(cls): + """ + Decorate a Keras Layer class to support Keras serialization. + + This is done by: + 1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at + serialization time + 2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and + convert it to a config object for the actual layer initializer + 3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does + not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model` + + :param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a + `TF*MainLayer` class in this project) + :return: the same class object, with modifications for Keras deserialization. + """ initializer = cls.__init__ config_class = getattr(cls, "config_class", None)