Support keras JSON/HDF5 serialization of main layers

Fixes #3101
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 14:00:30 +00:00
parent a088d75e51
commit ba28170717
11 changed files with 59 additions and 26 deletions

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)
@@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer):
return seq_relationship_score
class TFBertMainLayer(tf.keras.layers.Layer):
class TFBertMainLayer(TFMainLayer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
super().__init__(config, **kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = TFBertEmbeddings(config, name="embeddings")