CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) (#3186)
* memory benchmark rss * have both forward pass and line-by-line mem tracing * cleaned up tracing * refactored and cleaning up API * no f-strings yet... * add GPU mem logging * fix GPU memory monitoring * style and quality * clean up and doc * update with comments * Switching to python 3.6+ * fix quality
This commit is contained in:
@@ -39,6 +39,7 @@ from .file_utils import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from torch.nn import Identity
|
||||
except ImportError:
|
||||
@@ -66,6 +67,47 @@ class ModuleUtilsMixin:
|
||||
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
|
||||
@staticmethod
|
||||
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
|
||||
try:
|
||||
import psutil
|
||||
except (ImportError):
|
||||
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
mem = process.memory_info()
|
||||
module.mem_rss_pre_forward = mem.rss
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _hook_rss_memory_post_forward(module, *args, **kwargs):
|
||||
try:
|
||||
import psutil
|
||||
except (ImportError):
|
||||
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
mem = process.memory_info()
|
||||
module.mem_rss_post_forward = mem.rss
|
||||
mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
|
||||
module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
|
||||
return None
|
||||
|
||||
def add_memory_hooks(self):
|
||||
""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
|
||||
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
|
||||
"""
|
||||
for module in self.modules():
|
||||
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
|
||||
module.register_forward_hook(self._hook_rss_memory_post_forward)
|
||||
self.reset_memory_hooks_state()
|
||||
|
||||
def reset_memory_hooks_state(self):
|
||||
for module in self.modules():
|
||||
module.mem_rss_diff = 0
|
||||
module.mem_rss_post_forward = 0
|
||||
module.mem_rss_pre_forward = 0
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
r""" Base class for all models.
|
||||
|
||||
Reference in New Issue
Block a user