From 18a3cef7d523c061d4c00b65bbc41810f94782f5 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 11 Oct 2019 16:09:42 +0200 Subject: [PATCH] no nans --- transformers/tests/modeling_common_test.py | 8 +++++++- transformers/tests/modeling_tf_common_test.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index 298dcf3bdc..1c8b1584c7 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -88,7 +88,13 @@ class CommonTestCases: model = model_class.from_pretrained(tmpdirname) with torch.no_grad(): after_outputs = model(**inputs_dict) - max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy())) + + # Make sure we don't have nans + out_1 = after_outputs[0].numpy() + out_2 = outputs[0].numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) def test_initialization(self): diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py index 5b21ad15d7..360f86ea69 100644 --- a/transformers/tests/modeling_tf_common_test.py +++ b/transformers/tests/modeling_tf_common_test.py @@ -92,7 +92,13 @@ class TFCommonTestCases: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname) after_outputs = model(inputs_dict) - max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy())) + + # Make sure we don't have nans + out_1 = after_outputs[0].numpy() + out_2 = outputs[0].numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) def test_pt_tf_model_equivalence(self):