no nans
This commit is contained in:
@@ -88,7 +88,13 @@ class CommonTestCases:
|
|||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
after_outputs = model(**inputs_dict)
|
after_outputs = model(**inputs_dict)
|
||||||
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
|
||||||
|
# Make sure we don't have nans
|
||||||
|
out_1 = after_outputs[0].numpy()
|
||||||
|
out_2 = outputs[0].numpy()
|
||||||
|
out_1 = out_1[~np.isnan(out_1)]
|
||||||
|
out_2 = out_2[~np.isnan(out_2)]
|
||||||
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
|
|||||||
@@ -92,7 +92,13 @@ class TFCommonTestCases:
|
|||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
after_outputs = model(inputs_dict)
|
after_outputs = model(inputs_dict)
|
||||||
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
|
||||||
|
# Make sure we don't have nans
|
||||||
|
out_1 = after_outputs[0].numpy()
|
||||||
|
out_2 = outputs[0].numpy()
|
||||||
|
out_1 = out_1[~np.isnan(out_1)]
|
||||||
|
out_2 = out_2[~np.isnan(out_2)]
|
||||||
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user