Merge pull request #3103 from gthb/keras-serialization
Support keras JSON/HDF5 serialization of main layers
This commit is contained in:
@@ -23,7 +23,7 @@ import tensorflow as tf
|
||||
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -164,7 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
return outputs
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
config_class = CTRLConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
Reference in New Issue
Block a user