From 324f361e914e2083cf1627a7cdba05936fe538b6 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 22 Sep 2020 15:31:13 +0200 Subject: [PATCH] Fix saving TF custom models (#7291) * Fix #7277 * Apply style * Add a full training pipeline test * Apply style --- src/transformers/modeling_tf_utils.py | 26 +++++------ tests/test_modeling_tf_common.py | 63 +++++++++++++++++++++++++++ tests/test_modeling_tf_funnel.py | 2 +- 3 files changed, 77 insertions(+), 14 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index afccec7bc5..80fe5f91c7 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -85,20 +85,20 @@ def keras_serializable(cls): @functools.wraps(initializer) def wrapped_init(self, *args, **kwargs): - transformers_config = kwargs.pop("transformers_config", None) - config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None) - if config is not None and transformers_config is not None: - raise ValueError("Must pass either `config` or `transformers_config`, not both") - elif config is not None: - # normal layer construction, call with unchanged args (config is already in there) - initializer(self, *args, **kwargs) - elif transformers_config is not None: - # Keras deserialization, convert dict to config - config = config_class.from_dict(transformers_config) + config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None) + + if isinstance(config, dict): + config = config_class.from_dict(config) initializer(self, config, *args, **kwargs) + elif isinstance(config, PretrainedConfig): + if len(args) > 0: + initializer(self, *args, **kwargs) + else: + initializer(self, config, *args, **kwargs) else: - raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)") - self._transformers_config = config + raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)") + + self._config = config self._kwargs = kwargs cls.__init__ = wrapped_init @@ -109,7 +109,7 @@ def keras_serializable(cls): def get_config(self): cfg = super(cls, self).get_config() - cfg["transformers_config"] = self._transformers_config.to_dict() + cfg["config"] = self._config.to_dict() cfg.update(self._kwargs) return cfg diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index b251b890c4..1ce67b1be7 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -354,6 +354,69 @@ class TFModelTesterMixin: max_diff = np.amax(np.abs(tfo - pto)) self.assertLessEqual(max_diff, 4e-2) + def test_train_pipeline_custom_model(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + tf_main_layer_classes = set( + module_member + for model_class in self.all_model_classes + for module in (import_module(model_class.__module__),) + for module_member_name in dir(module) + if module_member_name.endswith("MainLayer") + for module_member in (getattr(module, module_member_name),) + if isinstance(module_member, type) + and tf.keras.layers.Layer in module_member.__bases__ + and getattr(module_member, "_keras_serializable", False) + ) + + for main_layer_class in tf_main_layer_classes: + # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter + if "T5" in main_layer_class.__name__: + # Take the same values than in TFT5ModelTester for this shared layer + shared = TFSharedEmbeddings(self.model_tester.vocab_size, self.model_tester.hidden_size, name="shared") + config.use_cache = False + main_layer = main_layer_class(config, embed_tokens=shared) + del inputs_dict["use_cache"] + else: + main_layer = main_layer_class(config) + + symbolic_inputs = { + name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() + } + + if hasattr(self.model_tester, "num_labels"): + num_labels = self.model_tester.num_labels + else: + num_labels = 2 + + X = tf.data.Dataset.from_tensor_slices( + (inputs_dict, np.random.randint(0, num_labels, (self.model_tester.batch_size, 1))) + ).batch(1) + + hidden_states = main_layer(symbolic_inputs)[0] + outputs = tf.keras.layers.Dense(num_labels, activation="softmax", name="outputs")(hidden_states) + model = tf.keras.models.Model(inputs=symbolic_inputs, outputs=[outputs]) + + model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["acc"]) + model.fit(X, epochs=1) + + with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, "keras_model.h5") + model.save(filepath) + if "T5" in main_layer_class.__name__: + model = tf.keras.models.load_model( + filepath, + custom_objects={ + main_layer_class.__name__: main_layer_class, + "TFSharedEmbeddings": TFSharedEmbeddings, + }, + ) + else: + model = tf.keras.models.load_model( + filepath, custom_objects={main_layer_class.__name__: main_layer_class} + ) + assert isinstance(model, tf.keras.Model) + model(inputs_dict) + def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_funnel.py b/tests/test_modeling_tf_funnel.py index 12567e93fb..bb723c8d5b 100644 --- a/tests/test_modeling_tf_funnel.py +++ b/tests/test_modeling_tf_funnel.py @@ -327,7 +327,7 @@ class TFFunnelModelTester: @require_tf -class FunnelModelTest(TFModelTesterMixin, unittest.TestCase): +class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( ( TFFunnelModel,