diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index c58243965a..64cd93863a 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer): return outputs # last-layer hidden state, (all hidden states), (all attentions) -@keras_serializable class TFDistilBertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 954091a5bd..50e0af3e6c 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, (attentions) -@keras_serializable class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index b0f56edcfa..f233074895 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer): # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### -@keras_serializable class TFT5MainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index c024aa8704..e10afb37a8 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -71,6 +71,7 @@ def keras_serializable(cls): cls.get_config = get_config + cls._keras_serializable = True return tf.keras.utils.register_keras_serializable()(cls) diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 2b7f49d7f0..8a3f200e6f 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer): return x -@keras_serializable class TFXLMMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 6b4fb24f5b..8f4cae7491 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -103,11 +103,9 @@ class TFModelTesterMixin: if module_member_name.endswith("MainLayer") for module_member in (getattr(module, module_member_name),) if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__ + and getattr(module_member, '_keras_serializable', False) ) for main_layer_class in tf_main_layer_classes: - if main_layer_class.__name__ == "TFT5MainLayer": - # Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly - continue main_layer = main_layer_class(config) symbolic_inputs = { name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()