From bc44e947f371924db854a460484ec46c95e50a35 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 31 Jan 2023 17:32:25 +0100 Subject: [PATCH] Update `Graphormer` and fix its `torchscript` test failures (#21380) * fix Co-authored-by: ydshieh --- .../models/graphormer/modeling_graphormer.py | 14 ++- .../graphormer/test_modeling_graphormer.py | 90 ++++++++++++++++++- 2 files changed, 99 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index ec32faddba..82b8b9f876 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -798,9 +798,11 @@ class GraphormerModel(GraphormerPreTrainedModel): attn_edge_type, perturb=None, masked_tokens=None, - return_dict: Optional[bool] = True, + return_dict: Optional[bool] = None, **unused ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + inner_states, graph_rep = self.graph_encoder( input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb ) @@ -819,7 +821,7 @@ class GraphormerModel(GraphormerPreTrainedModel): input_nodes = torch.nn.functional.linear(input_nodes, self.graph_encoder.embed_tokens.weight) if not return_dict: - return (input_nodes, inner_states) + return tuple(x for x in [input_nodes, inner_states] if x is not None) return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states) def max_nodes(self): @@ -860,9 +862,11 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel): spatial_pos, attn_edge_type, labels: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = True, + return_dict: Optional[bool] = None, **unused, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + encoder_outputs = self.encoder( input_nodes, input_edges, @@ -871,12 +875,14 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel): out_degree, spatial_pos, attn_edge_type, + return_dict=True, ) outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"] head_outputs = self.classifier(outputs) logits = head_outputs[:, 0, :].contiguous() + loss = None if labels is not None: mask = ~torch.isnan(labels) @@ -891,5 +897,5 @@ class GraphormerForGraphClassification(GraphormerPreTrainedModel): loss = loss_fct(logits[mask], labels[mask]) if not return_dict: - return (loss, logits, hidden_states) + return tuple(x for x in [loss, logits, hidden_states] if x is not None) return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None) diff --git a/tests/models/graphormer/test_modeling_graphormer.py b/tests/models/graphormer/test_modeling_graphormer.py index ed692c8868..90698d2781 100644 --- a/tests/models/graphormer/test_modeling_graphormer.py +++ b/tests/models/graphormer/test_modeling_graphormer.py @@ -17,13 +17,15 @@ import copy import inspect +import os +import tempfile import unittest from transformers import GraphormerConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor if is_torch_available(): @@ -255,6 +257,92 @@ class GraphormerModelTest(ModelTesterMixin, unittest.TestCase): self.model_tester = GraphormerModelTester(self) self.config_tester = ConfigTester(self, config_class=GraphormerConfig, has_text_modality=False) + # overwrite from common as `Graphormer` requires more input arguments + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + try: + required_keys = ( + "input_nodes", + "input_edges", + "attn_bias", + "in_degree", + "out_degree", + "spatial_pos", + "attn_edge_type", + ) + required_inputs = tuple(inputs[k] for k in required_keys) + model(*required_inputs) + traced_model = torch.jit.trace(model, required_inputs) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + def test_config(self): self.config_tester.run_common_tests()