no nans
This commit is contained in:
@@ -88,7 +88,13 @@ class CommonTestCases:
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
with torch.no_grad():
|
||||
after_outputs = model(**inputs_dict)
|
||||
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
||||
|
||||
# Make sure we don't have nans
|
||||
out_1 = after_outputs[0].numpy()
|
||||
out_2 = outputs[0].numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_initialization(self):
|
||||
|
||||
@@ -92,7 +92,13 @@ class TFCommonTestCases:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
after_outputs = model(inputs_dict)
|
||||
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
||||
|
||||
# Make sure we don't have nans
|
||||
out_1 = after_outputs[0].numpy()
|
||||
out_2 = outputs[0].numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
|
||||
Reference in New Issue
Block a user