Merge pull request #3103 from gthb/keras-serialization

Support keras JSON/HDF5 serialization of main layers
This commit is contained in:
Thomas Wolf
2020-03-06 12:59:13 +01:00
committed by GitHub
8 changed files with 134 additions and 14 deletions

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -164,7 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return outputs
@keras_serializable
class TFCTRLMainLayer(tf.keras.layers.Layer):
config_class = CTRLConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states