Fix Graphormer test suite (#21419)

* [FIX] path for Graphormer checkpoint

* [FIX] Test suite for graphormer

* [FIX] Update graphormer default num_classes
This commit is contained in:
Clémentine Fourrier
2023-02-02 16:29:13 +01:00
committed by GitHub
parent e006ab51ac
commit 67a3920d85
2 changed files with 9 additions and 13 deletions

View File

@@ -40,7 +40,7 @@ class GraphormerModelTester:
def __init__(
self,
parent,
num_classes=2,
num_classes=1,
num_atoms=512 * 9,
num_edges=512 * 3,
num_in_degree=512,
@@ -614,7 +614,7 @@ class GraphormerModelIntegrationTest(unittest.TestCase):
[3, 3, 4, 3, 3, 3, 3, 4, 4, 3, 4, 2, 2, 0, 0, 0, 0],
]
),
"x": tensor(
"input_nodes": tensor(
[
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]],
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [0], [0], [0], [0]],
@@ -1279,15 +1279,11 @@ class GraphormerModelIntegrationTest(unittest.TestCase):
output = model(**model_input)["logits"]
print(output.shape)
print(output)
expected_shape = torch.Size(())
expected_shape = torch.Size((2, 1))
self.assertEqual(output.shape, expected_shape)
# TODO Replace values below with what was printed above.
expected_slice = torch.tensor(
[[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
expected_logs = torch.tensor(
[[7.6060], [7.4126]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
self.assertTrue(torch.allclose(output, expected_logs, atol=1e-4))