@@ -285,7 +285,7 @@ class EsmModelIntegrationTest(TestCasePlus):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[[15.0973, -6.6406, -1.1351], [-0.2209, -9.9622, 4.2109], [-1.6055, -10.0023, 1.5914]]]
|
||||
[[[8.9215, -10.5898, -6.4671], [-6.3967, -13.9114, -1.1212], [-7.7812, -13.9516, -3.7406]]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user