Put @keras_serializable only on layers it works on
And only run the test on TF*MainLayer classes so marked.
This commit is contained in:
@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
|
||||
return outputs # x, (attentions)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
|
||||
# The full model without a specific pretrained or finetuning head is
|
||||
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
|
||||
####################################################
|
||||
@keras_serializable
|
||||
class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -71,6 +71,7 @@ def keras_serializable(cls):
|
||||
|
||||
cls.get_config = get_config
|
||||
|
||||
cls._keras_serializable = True
|
||||
return tf.keras.utils.register_keras_serializable()(cls)
|
||||
|
||||
|
||||
|
||||
@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
|
||||
return x
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -103,11 +103,9 @@ class TFModelTesterMixin:
|
||||
if module_member_name.endswith("MainLayer")
|
||||
for module_member in (getattr(module, module_member_name),)
|
||||
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
|
||||
and getattr(module_member, '_keras_serializable', False)
|
||||
)
|
||||
for main_layer_class in tf_main_layer_classes:
|
||||
if main_layer_class.__name__ == "TFT5MainLayer":
|
||||
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
|
||||
continue
|
||||
main_layer = main_layer_class(config)
|
||||
symbolic_inputs = {
|
||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||
|
||||
Reference in New Issue
Block a user