This commit is contained in:
thomwolf
2019-10-11 16:09:42 +02:00
parent 1f5d9513d8
commit 18a3cef7d5
2 changed files with 14 additions and 2 deletions

View File

@@ -88,7 +88,13 @@ class CommonTestCases:
model = model_class.from_pretrained(tmpdirname)
with torch.no_grad():
after_outputs = model(**inputs_dict)
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_initialization(self):