Update Graphormer and fix its torchscript test failures (#21380)
* fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -798,9 +798,11 @@ class GraphormerModel(GraphormerPreTrainedModel):
|
||||
attn_edge_type,
|
||||
perturb=None,
|
||||
masked_tokens=None,
|
||||
return_dict: Optional[bool] = True,
|
||||
return_dict: Optional[bool] = None,
|
||||
**unused
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
inner_states, graph_rep = self.graph_encoder(
|
||||
input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb
|
||||
)
|
||||
@@ -819,7 +821,7 @@ class GraphormerModel(GraphormerPreTrainedModel):
|
||||
input_nodes = torch.nn.functional.linear(input_nodes, self.graph_encoder.embed_tokens.weight)
|
||||
|
||||
if not return_dict:
|
||||
return (input_nodes, inner_states)
|
||||
return tuple(x for x in [input_nodes, inner_states] if x is not None)
|
||||
return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states)
|
||||
|
||||
def max_nodes(self):
|
||||
@@ -860,9 +862,11 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
|
||||
spatial_pos,
|
||||
attn_edge_type,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = True,
|
||||
return_dict: Optional[bool] = None,
|
||||
**unused,
|
||||
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
input_nodes,
|
||||
input_edges,
|
||||
@@ -871,12 +875,14 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
|
||||
out_degree,
|
||||
spatial_pos,
|
||||
attn_edge_type,
|
||||
return_dict=True,
|
||||
)
|
||||
outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"]
|
||||
|
||||
head_outputs = self.classifier(outputs)
|
||||
logits = head_outputs[:, 0, :].contiguous()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
mask = ~torch.isnan(labels)
|
||||
|
||||
@@ -891,5 +897,5 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
|
||||
loss = loss_fct(logits[mask], labels[mask])
|
||||
|
||||
if not return_dict:
|
||||
return (loss, logits, hidden_states)
|
||||
return tuple(x for x in [loss, logits, hidden_states] if x is not None)
|
||||
return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None)
|
||||
|
||||
@@ -17,13 +17,15 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import GraphormerConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -255,6 +257,92 @@ class GraphormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester = GraphormerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=GraphormerConfig, has_text_modality=False)
|
||||
|
||||
# overwrite from common as `Graphormer` requires more input arguments
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
try:
|
||||
required_keys = (
|
||||
"input_nodes",
|
||||
"input_edges",
|
||||
"attn_bias",
|
||||
"in_degree",
|
||||
"out_degree",
|
||||
"spatial_pos",
|
||||
"attn_edge_type",
|
||||
)
|
||||
required_inputs = tuple(inputs[k] for k in required_keys)
|
||||
model(*required_inputs)
|
||||
traced_model = torch.jit.trace(model, required_inputs)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
torch.jit.save(traced_model, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load(pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't load module.")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loaded_model.to(torch_device)
|
||||
loaded_model.eval()
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
loaded_model_state_dict = loaded_model.state_dict()
|
||||
|
||||
non_persistent_buffers = {}
|
||||
for key in loaded_model_state_dict.keys():
|
||||
if key not in model_state_dict.keys():
|
||||
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||
|
||||
loaded_model_state_dict = {
|
||||
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||
}
|
||||
|
||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||
|
||||
model_buffers = list(model.buffers())
|
||||
for non_persistent_buffer in non_persistent_buffers.values():
|
||||
found_buffer = False
|
||||
for i, model_buffer in enumerate(model_buffers):
|
||||
if torch.equal(non_persistent_buffer, model_buffer):
|
||||
found_buffer = True
|
||||
break
|
||||
|
||||
self.assertTrue(found_buffer)
|
||||
model_buffers.pop(i)
|
||||
|
||||
models_equal = True
|
||||
for layer_name, p1 in model_state_dict.items():
|
||||
if layer_name in loaded_model_state_dict:
|
||||
p2 = loaded_model_state_dict[layer_name]
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user