fix: work with Tensorflow < 2.1.0

tf.keras.utils.register_keras_serializable was added in TF 2.1.0, so
don't rely on it being there; just decorate the class with it if it
exists.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-04 16:57:28 +00:00
parent 96c4990165
commit 18f4b9274f

View File

@@ -72,7 +72,9 @@ def keras_serializable(cls):
cls.get_config = get_config
cls._keras_serializable = True
return tf.keras.utils.register_keras_serializable()(cls)
if hasattr(tf.keras.utils, "register_keras_serializable"):
cls = tf.keras.utils.register_keras_serializable()(cls)
return cls
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):