Fix Graphormer test suite (#21419)
* [FIX] path for Graphormer checkpoint * [FIX] Test suite for graphormer * [FIX] Update graphormer default num_classes
This commit is contained in:
committed by
GitHub
parent
e006ab51ac
commit
67a3920d85
@@ -39,8 +39,8 @@ class GraphormerConfig(PretrainedConfig):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_classes (`int`, *optional*, defaults to 2):
|
num_classes (`int`, *optional*, defaults to 1):
|
||||||
Number of target classes or labels, set to 1 if the task is a regression task.
|
Number of target classes or labels, set to n for binary classification of n tasks.
|
||||||
num_atoms (`int`, *optional*, defaults to 512*9):
|
num_atoms (`int`, *optional*, defaults to 512*9):
|
||||||
Number of node types in the graphs.
|
Number of node types in the graphs.
|
||||||
num_edges (`int`, *optional*, defaults to 512*3):
|
num_edges (`int`, *optional*, defaults to 512*3):
|
||||||
@@ -134,7 +134,7 @@ class GraphormerConfig(PretrainedConfig):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_classes: int = 2,
|
num_classes: int = 1,
|
||||||
num_atoms: int = 512 * 9,
|
num_atoms: int = 512 * 9,
|
||||||
num_edges: int = 512 * 3,
|
num_edges: int = 512 * 3,
|
||||||
num_in_degree: int = 512,
|
num_in_degree: int = 512,
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class GraphormerModelTester:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
num_classes=2,
|
num_classes=1,
|
||||||
num_atoms=512 * 9,
|
num_atoms=512 * 9,
|
||||||
num_edges=512 * 3,
|
num_edges=512 * 3,
|
||||||
num_in_degree=512,
|
num_in_degree=512,
|
||||||
@@ -614,7 +614,7 @@ class GraphormerModelIntegrationTest(unittest.TestCase):
|
|||||||
[3, 3, 4, 3, 3, 3, 3, 4, 4, 3, 4, 2, 2, 0, 0, 0, 0],
|
[3, 3, 4, 3, 3, 3, 3, 4, 4, 3, 4, 2, 2, 0, 0, 0, 0],
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"x": tensor(
|
"input_nodes": tensor(
|
||||||
[
|
[
|
||||||
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]],
|
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]],
|
||||||
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [0], [0], [0], [0]],
|
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [0], [0], [0], [0]],
|
||||||
@@ -1279,15 +1279,11 @@ class GraphormerModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
output = model(**model_input)["logits"]
|
output = model(**model_input)["logits"]
|
||||||
|
|
||||||
print(output.shape)
|
expected_shape = torch.Size((2, 1))
|
||||||
print(output)
|
|
||||||
|
|
||||||
expected_shape = torch.Size(())
|
|
||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
# TODO Replace values below with what was printed above.
|
expected_logs = torch.tensor(
|
||||||
expected_slice = torch.tensor(
|
[[7.6060], [7.4126]]
|
||||||
[[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(output, expected_logs, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user