@@ -285,7 +285,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
|||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
|
[[[8.9215, -10.5898, -6.4671], [-6.3967, -13.9114, -1.1212], [-7.7812, -13.9516, -3.7406]]]
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user