fix: work with Tensorflow < 2.1.0
tf.keras.utils.register_keras_serializable was added in TF 2.1.0, so don't rely on it being there; just decorate the class with it if it exists.
This commit is contained in:
@@ -72,7 +72,9 @@ def keras_serializable(cls):
|
|||||||
cls.get_config = get_config
|
cls.get_config = get_config
|
||||||
|
|
||||||
cls._keras_serializable = True
|
cls._keras_serializable = True
|
||||||
return tf.keras.utils.register_keras_serializable()(cls)
|
if hasattr(tf.keras.utils, "register_keras_serializable"):
|
||||||
|
cls = tf.keras.utils.register_keras_serializable()(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||||
|
|||||||
Reference in New Issue
Block a user