[modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True) + tests (#16657)

* add low_cpu_mem_usage tests

* wip: revamping

* wip

* install /usr/bin/time

* wip

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* fix assert

* put the wrapper back

* cleanup; switch to bert-base-cased

* Trigger CI

* Trigger CI
This commit is contained in:
Stas Bekman
2022-04-14 18:10:05 -07:00
committed by GitHub
parent ce2fef2ad2
commit 5da33f8729
4 changed files with 254 additions and 83 deletions

View File

@@ -17,6 +17,7 @@ import inspect
import logging
import os
import re
import shlex
import shutil
import sys
import tempfile
@@ -667,6 +668,20 @@ def require_librosa(test_case):
return test_case
def cmd_exists(cmd):
    """
    Tell whether `cmd` can be resolved to an executable by `shutil.which`.

    Args:
        cmd (`str`):
            command name (looked up on the search path) or a path to an executable

    Returns:
        `True` if `shutil.which` found the command, `False` otherwise.
    """
    resolved_path = shutil.which(cmd)
    return resolved_path is not None
def require_usr_bin_time(test_case):
    """
    Decorator marking a test that requires `/usr/bin/time`.

    When `/usr/bin/time` is available the test is returned unchanged, otherwise it is wrapped with
    `unittest.skip` so that it gets reported as skipped.
    """
    # guard clause: nothing to wrap when the binary is present
    if cmd_exists("/usr/bin/time"):
        return test_case

    skip_decorator = unittest.skip("test requires /usr/bin/time")
    return skip_decorator(test_case)
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch, tf or jax is used)
@@ -1178,6 +1193,39 @@ class TestCasePlus(unittest.TestCase):
return tmp_dir
def python_one_liner_max_rss(self, one_liner_str):
    """
    Runs the passed python one liner (just the code) in a subprocess wrapped by `/usr/bin/time -f %M` and returns
    how much max cpu memory was used to run the program.

    Args:
        one_liner_str (`string`):
            a python one liner code that gets passed to `python -c`. It is forwarded as a single argv entry, so
            it may contain both single and double quotes.

    Returns:
        max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.

    Raises:
        ValueError: if `/usr/bin/time` is not installed, or if its report cannot be found or parsed in the
        captured stderr.

    Requirements:
        this helper needs `/usr/bin/time` to be installed (`apt install time`)

    Example:

    ```
    one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")'
    max_rss = self.python_one_liner_max_rss(one_liner_str)
    ```
    """
    if not cmd_exists("/usr/bin/time"):
        raise ValueError("/usr/bin/time is required, install with `apt install time`")

    # Only the fixed prefix goes through shlex. The interpreter path and the user code are appended as discrete
    # argv entries, so quotes or spaces inside them can never be mis-split. `sys.executable` makes the child
    # use the same interpreter that is running the tests instead of whatever `python` is first on PATH.
    cmd = shlex.split("/usr/bin/time -f %M") + [sys.executable, "-c", one_liner_str]

    with CaptureStd() as cs:
        execute_subprocess_async(cmd, env=self.get_env())

    # `/usr/bin/time` writes its report last, so take the last non-blank captured line instead of relying on
    # the capture ending with exactly one trailing newline
    report_lines = [line for line in cs.err.splitlines() if line.strip()]
    if not report_lines:
        raise ValueError("could not find the /usr/bin/time report in the captured stderr")

    # the captured line may carry a "stderr: " marker - presumably added when the subprocess output is echoed -
    # drop it so that only the number is left
    max_rss_kb = report_lines[-1].replace("stderr: ", "")

    # returned data is in KB so convert to bytes
    return int(max_rss_kb) * 1024
def tearDown(self):
# get_auto_remove_tmp_dir feature: remove registered temp dirs