Support keras JSON/HDF5 serialization of main layers

Fixes #3101
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 14:00:30 +00:00
parent a088d75e51
commit ba28170717
11 changed files with 59 additions and 26 deletions

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
logger = logging.getLogger(__name__)
@@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return outputs
class TFCTRLMainLayer(tf.keras.layers.Layer):
class TFCTRLMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past