Speed up model init on CPU (by 10x+ for llama-3-8B as one example) (#31771)
* 1,100%! * Clean * Don't touch DS * Experiment with dtype allocation * skip test_load_save_without_tied_weights test * A little faster * Include proper upscaling? * Fixup tests * Potentially skip? * Let's see if this fixes git history * Maintain new dtype * Fin * Rm hook idea for now * New approach, see what breaks * stage * Clean * Stash * Should be fin now, just need to mark failing models * Clean up * Simplify * Deal with weird models * Enc/Dec * Skip w/ reason * Adjust test * Fix test * one more test * Keep experimenting * Fix ref * TO REMOVE: testing feedback CI * Right push * Update tests/utils/test_modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * disable * Add new func * Test nits from Amy * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Adjust comment * Adjust comment on skip * make private * Fin * Should be a not flag * Clarify and rename test --------- Co-authored-by: Marc Sun <marc@huggingface.co> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,7 @@ import os.path
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
@@ -894,32 +895,42 @@ class ModelUtilsTest(TestCasePlus):
|
||||
@require_usr_bin_time
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_from_pretrained_low_cpu_mem_usage_measured(self):
|
||||
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
|
||||
def test_from_pretrained_low_cpu_mem_usage_slower(self):
|
||||
# Previously this tested that `from_pretrained(..., low_cpu_mem_usage=True)` uses less CPU memory than the default
|
||||
# Now that the memory usage is the same either way, we simply test that loading with `low_cpu_mem_usage` winds up being *slower*
|
||||
# (mostly from extra logic needed)
|
||||
|
||||
mname = "google-bert/bert-base-cased"
|
||||
mname = "hf-internal-testing/tiny-random-bert"
|
||||
|
||||
preamble = "from transformers import AutoModel"
|
||||
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
|
||||
start_time = time.time()
|
||||
# Save this output as `max_rss_normal` if testing memory results
|
||||
max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
|
||||
end_time = time.time()
|
||||
elapsed_time_normal = end_time - start_time
|
||||
# print(f"{max_rss_normal=}")
|
||||
|
||||
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
|
||||
start_time = time.time()
|
||||
# Save this output as `max_rss_low_mem` if testing memory results
|
||||
max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
|
||||
# print(f"{max_rss_low_mem=}")
|
||||
end_time = time.time()
|
||||
elapsed_time_low_mem = end_time - start_time
|
||||
|
||||
diff_bytes = max_rss_normal - max_rss_low_mem
|
||||
diff_percent = diff_bytes / max_rss_low_mem
|
||||
# print(f"{diff_bytes=}, {diff_percent=}")
|
||||
# ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
|
||||
# measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
|
||||
# it's at least 15% less cpu memory consumed
|
||||
# Should be within 2 MB of each other (overhead)
|
||||
self.assertAlmostEqual(
|
||||
max_rss_normal / 1024 / 1024,
|
||||
max_rss_low_mem / 1024 / 1024,
|
||||
delta=2,
|
||||
msg="using `low_cpu_mem_usage` should incur the same memory usage in both cases.",
|
||||
)
|
||||
|
||||
self.assertGreater(
|
||||
diff_percent,
|
||||
0.15,
|
||||
"should use less CPU memory for low_cpu_mem_usage=True, "
|
||||
f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
|
||||
elapsed_time_low_mem,
|
||||
elapsed_time_normal,
|
||||
"using `low_cpu_mem_usage` should be slower due to extra logic, "
|
||||
f"but got elapsed_time_normal={elapsed_time_normal} and elapsed_time_low_mem={elapsed_time_low_mem}",
|
||||
)
|
||||
|
||||
# if you want to compare things manually, let's first look at the size of the model in bytes
|
||||
|
||||
Reference in New Issue
Block a user