Put @keras_serializable only on layers it works on
And only run the test on TF*MainLayer classes so marked.
This commit is contained in:
@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
|
||||
return outputs # x, (attentions)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
|
||||
# The full model without a specific pretrained or finetuning head is
|
||||
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
|
||||
####################################################
|
||||
@keras_serializable
|
||||
class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -71,6 +71,7 @@ def keras_serializable(cls):
|
||||
|
||||
cls.get_config = get_config
|
||||
|
||||
cls._keras_serializable = True
|
||||
return tf.keras.utils.register_keras_serializable()(cls)
|
||||
|
||||
|
||||
|
||||
@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
|
||||
return x
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user