[modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True) + tests (#16657)
* add low_cpu_mem_usage tests * wip: revamping * wip * install /usr/bin/time * wip * cleanup * cleanup * cleanup * cleanup * cleanup * fix assert * put the wrapper back * cleanup; switch to bert-base-cased * Trigger CI * Trigger CI
This commit is contained in:
@@ -17,6 +17,7 @@ import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -667,6 +668,20 @@ def require_librosa(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def cmd_exists(cmd):
    """
    Check whether `cmd` can be resolved to an executable by `shutil.which`.

    Args:
        cmd (`str`):
            a command name or a path to an executable

    Returns:
        `True` if `shutil.which` finds the command, `False` otherwise.
    """
    resolved_path = shutil.which(cmd)
    return resolved_path is not None
|
||||
|
||||
|
||||
def require_usr_bin_time(test_case):
    """
    Decorator marking a test that requires `/usr/bin/time`.

    The test is returned unchanged when `/usr/bin/time` is found; otherwise it is wrapped with `unittest.skip` so that
    it gets reported as skipped.
    """
    # guard clause: the tool is present, nothing to wrap
    if cmd_exists("/usr/bin/time"):
        return test_case

    skip_decorator = unittest.skip("test requires /usr/bin/time")
    return skip_decorator(test_case)
|
||||
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
||||
@@ -1178,6 +1193,39 @@ class TestCasePlus(unittest.TestCase):
|
||||
|
||||
return tmp_dir
|
||||
|
||||
def python_one_liner_max_rss(self, one_liner_str):
    """
    Runs the passed python one liner (just the code) under `/usr/bin/time` and returns how much max cpu memory was
    used to run the program.

    Args:
        one_liner_str (`string`):
            a python one liner code that gets passed to `python -c`. It is handed over as a single argument, so it
            may freely mix single and double quotes.

    Returns:
        max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.

    Raises:
        `ValueError`: if `/usr/bin/time` is not available.

    Requirements:
        this helper needs `/usr/bin/time` to be installed (`apt install time`)

    Example:

    ```
    one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")'
    max_rss = self.python_one_liner_max_rss(one_liner_str)
    ```
    """

    if not cmd_exists("/usr/bin/time"):
        raise ValueError("/usr/bin/time is required, install with `apt install time`")

    # Build the argument list directly instead of interpolating the code into a quoted string and re-parsing it
    # with `shlex.split`: a one liner containing a single quote would otherwise be mis-tokenized (or fail to parse).
    # For quote-free input this yields exactly the same argv as before.
    cmd = ["/usr/bin/time", "-f", "%M", "python", "-c", one_liner_str]
    with CaptureStd() as cs:
        execute_subprocess_async(cmd, env=self.get_env())

    # The `%M` report is read from the second to last line of the captured stderr (the last element of the split is
    # what follows the final newline); the "stderr: " prefix is presumably added by `execute_subprocess_async` when
    # it echoes the child's output - TODO confirm.
    reported_kb = cs.err.split("\n")[-2].replace("stderr: ", "")
    # returned data is in KB so convert to bytes
    max_rss = int(reported_kb) * 1024
    return max_rss
|
||||
|
||||
def tearDown(self):
|
||||
|
||||
# get_auto_remove_tmp_dir feature: remove registered temp dirs
|
||||
|
||||
Reference in New Issue
Block a user