From ea2600bd5f1d36f2fb61958be21db5b901e33884 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 27 Jan 2020 21:57:23 -0500 Subject: [PATCH] Absolute definitive HeisenDistilBug solve cc @julien-c @thomwolf --- tests/test_modeling_tf_common.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 97bccd2cbe..bcfb6bfe5d 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -113,10 +113,13 @@ class TFModelTesterMixin: tf_hidden_states = tfo[0].numpy() pt_hidden_states = pto[0].numpy() - pt_hidden_states[np.isnan(tf_hidden_states)] = 0 - tf_hidden_states[np.isnan(tf_hidden_states)] = 0 - pt_hidden_states[np.isnan(pt_hidden_states)] = 0 - tf_hidden_states[np.isnan(pt_hidden_states)] = 0 + tf_nans = np.copy(np.isnan(tf_hidden_states)) + pt_nans = np.copy(np.isnan(pt_hidden_states)) + + pt_hidden_states[tf_nans] = 0 + tf_hidden_states[tf_nans] = 0 + pt_hidden_states[pt_nans] = 0 + tf_hidden_states[pt_nans] = 0 max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) # Debug info (remove when fixed) @@ -148,8 +151,14 @@ class TFModelTesterMixin: tfo = tf_model(inputs_dict) tfo = tfo[0].numpy() pto = pto[0].numpy() - tfo[np.isnan(tfo)] = 0 - pto[np.isnan(pto)] = 0 + tf_nans = np.copy(np.isnan(tfo)) + pt_nans = np.copy(np.isnan(pto)) + + pto[tf_nans] = 0 + tfo[tf_nans] = 0 + pto[pt_nans] = 0 + tfo[pt_nans] = 0 + max_diff = np.amax(np.abs(tfo - pto)) self.assertLessEqual(max_diff, 2e-2)