CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) (#3186)

* memory benchmark rss

* have both forward pass and line-by-line mem tracing

* cleaned up tracing

* refactored and cleaning up API

* no f-strings yet...

* add GPU mem logging

* fix GPU memory monitoring

* style and quality

* clean up and doc

* update with comments

* Switching to python 3.6+

* fix quality
This commit is contained in:
Thomas Wolf
2020-03-17 15:17:11 +01:00
committed by GitHub
parent bd3feddf67
commit 2187c49f5c
11 changed files with 565 additions and 32 deletions

View File

@@ -39,6 +39,7 @@ from .file_utils import (
logger = logging.getLogger(__name__)
try:
from torch.nn import Identity
except ImportError:
@@ -66,6 +67,47 @@ class ModuleUtilsMixin:
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_pre_forward = mem.rss
return None
@staticmethod
def _hook_rss_memory_post_forward(module, *args, **kwargs):
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_post_forward = mem.rss
mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
return None
def add_memory_hooks(self):
""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
"""
for module in self.modules():
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
module.register_forward_hook(self._hook_rss_memory_post_forward)
self.reset_memory_hooks_state()
def reset_memory_hooks_state(self):
for module in self.modules():
module.mem_rss_diff = 0
module.mem_rss_post_forward = 0
module.mem_rss_pre_forward = 0
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
r""" Base class for all models.