From 099358675899f759110ad8ccecc22c2fab9b1888 Mon Sep 17 00:00:00 2001 From: JulianPani Date: Mon, 14 Oct 2019 02:09:53 +0300 Subject: [PATCH 1/2] remove usage of DUMMY_INPUTS Hey @thomwolf This change https://github.com/huggingface/transformers/commit/da26bae61b8c1e741fdc6735d46c61b43f649561#diff-8ddce309e88e8eb5b4d02228fd8881daL28-L29 removed the constant, but one usage of that constant remains in the code. --- transformers/modeling_tf_pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/modeling_tf_pytorch_utils.py b/transformers/modeling_tf_pytorch_utils.py index 5a70d9a72b..88ce4d4610 100644 --- a/transformers/modeling_tf_pytorch_utils.py +++ b/transformers/modeling_tf_pytorch_utils.py @@ -198,7 +198,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs tf_model = tf_model_class(pt_model.config) if tf_inputs is None: - tf_inputs = tf.constant(DUMMY_INPUTS) + tf_inputs = tf_model.dummy_inputs if tf_inputs is not None: tfo = tf_model(tf_inputs, training=False) # Make sure model is built From 898ce064f8c53b8744c51358d49eff51af0a8713 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 15 Oct 2019 10:04:19 +0200 Subject: [PATCH 2/2] add tests on TF2.0 & PT checkpoint => model convertion functions --- transformers/tests/modeling_tf_common_test.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py index 360f86ea69..f636c42889 100644 --- a/transformers/tests/modeling_tf_common_test.py +++ b/transformers/tests/modeling_tf_common_test.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import, division, print_function +import os import copy import json import logging @@ -118,7 +119,7 @@ class TFCommonTestCases: tf_model = model_class(config) pt_model = pt_model_class(config) - # Check we can load pt model in tf and vice-versa (architecture similar) + # Check we can load pt model in tf and vice-versa with model => model functions tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict) pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) @@ -132,6 +133,26 @@ class TFCommonTestCases: max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy())) self.assertLessEqual(max_diff, 2e-2) + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin') + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) + + tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5') + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) + + # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences + pt_model.eval() + pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long)) + for name, key in inputs_dict.items()) + with torch.no_grad(): + pto = pt_model(**pt_inputs_dict) + tfo = tf_model(inputs_dict) + max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy())) + self.assertLessEqual(max_diff, 2e-2) + def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()