Reduce memory leak in _create_and_check_torchscript (#16691)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -598,6 +598,13 @@ class ModelTesterMixin:
|
||||
config.output_hidden_states = True
|
||||
self._create_and_check_torchscript(config, inputs_dict)
|
||||
|
||||
# This is copied from `torch/testing/_internal/jit_utils.py::clear_class_registry`
|
||||
def clear_torch_jit_class_registry(self):
|
||||
|
||||
torch._C._jit_clear_class_registry()
|
||||
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
|
||||
torch.jit._state._clear_class_state()
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
@@ -679,6 +686,10 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
def test_torch_fx(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self._create_and_check_torch_fx_tracing(config, inputs_dict)
|
||||
|
||||
Reference in New Issue
Block a user