Put @keras_serializable only on layers it works on

And only run the test on TF*MainLayer classes so marked.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 22:44:38 +00:00
parent 0c716ede8c
commit 470753bcf5
6 changed files with 2 additions and 7 deletions

View File

@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
@keras_serializable
class TFDistilBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

View File

@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, (attentions)
@keras_serializable
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)

View File

@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
@keras_serializable
class TFT5MainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)

View File

@@ -71,6 +71,7 @@ def keras_serializable(cls):
cls.get_config = get_config
cls._keras_serializable = True
return tf.keras.utils.register_keras_serializable()(cls)

View File

@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return x
@keras_serializable
class TFXLMMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)