From 6dce6dda1bfaf5c54c3957e24b92100bceb9b255 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 9 Oct 2019 11:57:55 +0200 Subject: [PATCH] fixing TF 2.0 model - adding more severe test on pt/tf equivalence --- transformers/tests/modeling_tf_common_test.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py index 483f031b16..4b363c6dc8 100644 --- a/transformers/tests/modeling_tf_common_test.py +++ b/transformers/tests/modeling_tf_common_test.py @@ -71,6 +71,8 @@ class TFCommonTestCases: if not is_torch_available(): return + import torch + import numpy as np import transformers config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -79,12 +81,22 @@ class TFCommonTestCases: pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining pt_model_class = getattr(transformers, pt_model_class_name) - tf_model = model_class(config) - pt_model = pt_model_class(config) + tf_model = model_class(config, output_hidden_states=True) + pt_model = pt_model_class(config, output_hidden_states=True) + # Check we can load pt model in tf and vice-versa (architecture similar) tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict) pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) + # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences + pt_model.eval() + pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long)) + for name, key in inputs_dict.items()) + with torch.no_grad(): + pto = pt_model(**pt_inputs_dict) + tfo = tf_model(inputs_dict) + max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy())) + self.assertLessEqual(max_diff, 2e-2) def test_keyword_and_dict_args(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()