Put @keras_serializable only on layers it works on
And only run the test on TF*MainLayer classes so marked.
This commit is contained in:
@@ -103,11 +103,9 @@ class TFModelTesterMixin:
|
||||
if module_member_name.endswith("MainLayer")
|
||||
for module_member in (getattr(module, module_member_name),)
|
||||
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
|
||||
and getattr(module_member, '_keras_serializable', False)
|
||||
)
|
||||
for main_layer_class in tf_main_layer_classes:
|
||||
if main_layer_class.__name__ == "TFT5MainLayer":
|
||||
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
|
||||
continue
|
||||
main_layer = main_layer_class(config)
|
||||
symbolic_inputs = {
|
||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||
|
||||
Reference in New Issue
Block a user