From 3918d6a9d67e79c45746607d3ca726ddd641a3d1 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 11 Apr 2022 18:22:28 +0200 Subject: [PATCH] Reduce memory leak in _create_and_check_torchscript (#16691) Co-authored-by: ydshieh --- tests/test_modeling_common.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 752659248d..4b54ec45d5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -598,6 +598,13 @@ class ModelTesterMixin: config.output_hidden_states = True self._create_and_check_torchscript(config, inputs_dict) + # This is copied from `torch/testing/_internal/jit_utils.py::clear_class_registry` + def clear_torch_jit_class_registry(self): + + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._clear_class_state() + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: return @@ -679,6 +686,10 @@ class ModelTesterMixin: self.assertTrue(models_equal) + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + def test_torch_fx(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() self._create_and_check_torch_fx_tracing(config, inputs_dict)