Put @keras_serializable only on layers it works on

And only run the test on TF*MainLayer classes so marked.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 22:44:38 +00:00
parent 0c716ede8c
commit 470753bcf5
6 changed files with 2 additions and 7 deletions

View File

@@ -103,11 +103,9 @@ class TFModelTesterMixin:
if module_member_name.endswith("MainLayer")
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, '_keras_serializable', False)
)
for main_layer_class in tf_main_layer_classes:
if main_layer_class.__name__ == "TFT5MainLayer":
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
continue
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()