Use class decorator instead of superclass

When supplied by Keras deserialization, the config parameter to initializers
will be a dict. So intercept it and convert it to a PretrainedConfig object
(and store it in an instance attribute so that get_config can retrieve it)
before passing it to the actual initializer. To accomplish this while repeating
as little code as possible, use a class decorator on the TF*MainLayer classes.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 22:31:38 +00:00
parent b8da16f390
commit 0c716ede8c
14 changed files with 94 additions and 46 deletions

View File

@@ -23,7 +23,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
logger = logging.getLogger(__name__)
@@ -164,9 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return outputs
class TFCTRLMainLayer(TFMainLayer):
@keras_serializable
class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past