Absolute definitive HeisenDistilBug solve
cc @julien-c @thomwolf
This commit is contained in:
@@ -113,10 +113,13 @@ class TFModelTesterMixin:
|
|||||||
tf_hidden_states = tfo[0].numpy()
|
tf_hidden_states = tfo[0].numpy()
|
||||||
pt_hidden_states = pto[0].numpy()
|
pt_hidden_states = pto[0].numpy()
|
||||||
|
|
||||||
pt_hidden_states[np.isnan(tf_hidden_states)] = 0
|
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||||
tf_hidden_states[np.isnan(tf_hidden_states)] = 0
|
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||||
pt_hidden_states[np.isnan(pt_hidden_states)] = 0
|
|
||||||
tf_hidden_states[np.isnan(pt_hidden_states)] = 0
|
pt_hidden_states[tf_nans] = 0
|
||||||
|
tf_hidden_states[tf_nans] = 0
|
||||||
|
pt_hidden_states[pt_nans] = 0
|
||||||
|
tf_hidden_states[pt_nans] = 0
|
||||||
|
|
||||||
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
||||||
# Debug info (remove when fixed)
|
# Debug info (remove when fixed)
|
||||||
@@ -148,8 +151,14 @@ class TFModelTesterMixin:
|
|||||||
tfo = tf_model(inputs_dict)
|
tfo = tf_model(inputs_dict)
|
||||||
tfo = tfo[0].numpy()
|
tfo = tfo[0].numpy()
|
||||||
pto = pto[0].numpy()
|
pto = pto[0].numpy()
|
||||||
tfo[np.isnan(tfo)] = 0
|
tf_nans = np.copy(np.isnan(tfo))
|
||||||
pto[np.isnan(pto)] = 0
|
pt_nans = np.copy(np.isnan(pto))
|
||||||
|
|
||||||
|
pto[tf_nans] = 0
|
||||||
|
tfo[tf_nans] = 0
|
||||||
|
pto[pt_nans] = 0
|
||||||
|
tfo[pt_nans] = 0
|
||||||
|
|
||||||
max_diff = np.amax(np.abs(tfo - pto))
|
max_diff = np.amax(np.abs(tfo - pto))
|
||||||
self.assertLessEqual(max_diff, 2e-2)
|
self.assertLessEqual(max_diff, 2e-2)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user