Reduce memory leak in _create_and_check_torchscript (#16691)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-04-11 18:22:28 +02:00
committed by GitHub
parent 2109afae71
commit 3918d6a9d6

View File

@@ -598,6 +598,13 @@ class ModelTesterMixin:
config.output_hidden_states = True
self._create_and_check_torchscript(config, inputs_dict)
# This is copied from `torch/testing/_internal/jit_utils.py::clear_class_registry`
def clear_torch_jit_class_registry(self):
torch._C._jit_clear_class_registry()
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
torch.jit._state._clear_class_state()
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return
@@ -679,6 +686,10 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict)