Put @keras_serializable only on layers it works on
And only run the test on TF*MainLayer classes so marked.
This commit is contained in:
@@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
@keras_serializable
|
|
||||||
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
return outputs # x, (attentions)
|
return outputs # x, (attentions)
|
||||||
|
|
||||||
|
|
||||||
@keras_serializable
|
|
||||||
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
|
|||||||
@@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer):
|
|||||||
# The full model without a specific pretrained or finetuning head is
|
# The full model without a specific pretrained or finetuning head is
|
||||||
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
|
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
|
||||||
####################################################
|
####################################################
|
||||||
@keras_serializable
|
|
||||||
class TFT5MainLayer(tf.keras.layers.Layer):
|
class TFT5MainLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ def keras_serializable(cls):
|
|||||||
|
|
||||||
cls.get_config = get_config
|
cls.get_config = get_config
|
||||||
|
|
||||||
|
cls._keras_serializable = True
|
||||||
return tf.keras.utils.register_keras_serializable()(cls)
|
return tf.keras.utils.register_keras_serializable()(cls)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@keras_serializable
|
|
||||||
class TFXLMMainLayer(tf.keras.layers.Layer):
|
class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -103,11 +103,9 @@ class TFModelTesterMixin:
|
|||||||
if module_member_name.endswith("MainLayer")
|
if module_member_name.endswith("MainLayer")
|
||||||
for module_member in (getattr(module, module_member_name),)
|
for module_member in (getattr(module, module_member_name),)
|
||||||
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
|
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
|
||||||
|
and getattr(module_member, '_keras_serializable', False)
|
||||||
)
|
)
|
||||||
for main_layer_class in tf_main_layer_classes:
|
for main_layer_class in tf_main_layer_classes:
|
||||||
if main_layer_class.__name__ == "TFT5MainLayer":
|
|
||||||
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
|
|
||||||
continue
|
|
||||||
main_layer = main_layer_class(config)
|
main_layer = main_layer_class(config)
|
||||||
symbolic_inputs = {
|
symbolic_inputs = {
|
||||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||||
|
|||||||
Reference in New Issue
Block a user