Use class decorator instead of superclass

When supplied by Keras deserialization, the config parameter to initializers
will be a dict. So intercept it and convert to PretrainedConfig object (and
store in instance attribute for get_config to get at it) before passing to the
actual initializer. To accomplish this, and repeat as little code as possible,
use a class decorator on TF*MainLayer classes.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 22:31:38 +00:00
parent b8da16f390
commit 0c716ede8c
14 changed files with 94 additions and 46 deletions

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -471,9 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer):
return seq_relationship_score
class TFBertMainLayer(TFMainLayer):
@keras_serializable
class TFBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.embeddings = TFBertEmbeddings(config, name="embeddings")