From e0dfd7bcaf7ff0723085f23244a755cc2ed92466 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 16 Jul 2024 09:32:01 -0400 Subject: [PATCH] Speedup model init on CPU (by 10x+ for llama-3-8B as one example) (#31771) * 1,100%! * Clean * Don't touch DS * Experiment with dtype allocation * skip test_load_save_without_tied_weights test * A little faster * Include proper upscaling? * Fixup tests * Potentially skip? * Let's see if this fixes git history * Maintain new dtype * Fin * Rm hook idea for now * New approach, see what breaks * stage * Clean * Stash * Should be fin now, just need to mark failing models * Clean up * Simplify * Deal with weird models * Enc/Dec * Skip w/ reason * Adjust test * Fix test * one more test * Keep experimenting * Fix ref * TO REMOVE: testing feedback CI * Right push * Update tests/utils/test_modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * disable * Add new func * Test nits from Amy * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Adjust comment * Adjust comment on skip * make private * Fin * Should be a not flag * Clarify and rename test --------- Co-authored-by: Marc Sun Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/main_classes/model.md | 4 ++ src/transformers/modeling_utils.py | 64 +++++++++++++++++-- .../modeling_encoder_decoder.py | 1 + .../models/lxmert/modeling_lxmert.py | 1 + .../modeling_vision_encoder_decoder.py | 1 + tests/models/bart/test_modeling_bart.py | 6 ++ .../test_modeling_bigbird_pegasus.py | 6 ++ tests/models/longt5/test_modeling_longt5.py | 12 ++++ tests/models/lxmert/test_modeling_lxmert.py | 6 ++ tests/models/m2m_100/test_modeling_m2m_100.py | 6 ++ tests/models/mbart/test_modeling_mbart.py | 6 ++ .../models/nllb_moe/test_modeling_nllb_moe.py | 6 ++ tests/models/plbart/test_modeling_plbart.py | 6 ++ .../test_modeling_seamless_m4t.py | 12 ++++ .../test_modeling_seamless_m4t_v2.py | 12 ++++ .../test_modeling_switch_transformers.py | 12 ++++ tests/utils/test_modeling_utils.py | 39 +++++++---- 17 files changed, 180 insertions(+), 20 deletions(-) diff --git a/docs/source/en/main_classes/model.md b/docs/source/en/main_classes/model.md index a8ae2ad08b..15345a7b2a 100644 --- a/docs/source/en/main_classes/model.md +++ b/docs/source/en/main_classes/model.md @@ -40,6 +40,10 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models), - push_to_hub - all +Custom models should also include a `_supports_assign_param_buffer`, which determines if superfast init can apply +on the particular model. Signs that your model needs this are if `test_save_and_load_from_pretrained` fails. If so, +set this to `False`. + ## ModuleUtilsMixin [[autodoc]] modeling_utils.ModuleUtilsMixin diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e831ba3613..bf8457309f 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -338,6 +338,32 @@ def dtype_byte_size(dtype): return bit_size // 8 +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such + as when loading in empty weights) by first checking + if the model explicitly disables it, then by ensuring that the state dict keys + are a subset of the model's parameters. + """ + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", False): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = list(model_to_load.state_dict().keys())[0] + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) + return False + + def shard_checkpoint( state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME ): @@ -657,7 +683,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor] return shared_tensors, identical -def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] new_keys = [] @@ -685,8 +711,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. - def load(module: nn.Module, state_dict, prefix=""): + def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + local_metadata["assign_to_params_buffers"] = assign_to_params_buffers + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) # Parameters of module and children will start with prefix. We can exit early if there are none in this # state_dict @@ -710,9 +738,9 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): for name, child in module._modules.items(): if child is not None: - load(child, state_dict, prefix + name + ".") + load(child, state_dict, prefix + name + ".", assign_to_params_buffers) - load(model_to_load, state_dict, prefix=start_prefix) + load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers) # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so # it's safe to delete it. del state_dict @@ -2852,6 +2880,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those weights are discarded. + If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded + in using the `meta` device and brought into memory once an input is passed through that layer regardless of + `low_cpu_mem_usage`. + Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: @@ -2952,7 +2984,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix low_cpu_mem_usage(`bool`, *optional*): Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Generally should be combined with a `device_map` (such as `"auto"`) for best results. This is an experimental feature and a subject to change at any moment. + + If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without + `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However, + this should still be enabled if you are passing in a `device_map`. + torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under a specific `dtype`. The different options are: @@ -4018,6 +4056,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix missing_keys = sorted(set(expected_keys) - set(loaded_keys)) unexpected_keys = set(loaded_keys) - set(expected_keys) + # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model # buffers model_buffers = {n for n, _ in model.named_buffers()} @@ -4252,7 +4291,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) else: # Sharded checkpoint or whole but low_cpu_mem_usage==True - error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + error_msgs = _load_state_dict_into_model( + model_to_load, state_dict, start_prefix, assign_to_params_buffers + ) else: # This should always be a list but, just to be sure. @@ -4280,6 +4324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if len(resolved_archive_file) > 1: resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + assign_to_params_buffers = None for shard_file in resolved_archive_file: # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. if shard_file in disk_only_shard_files: @@ -4323,7 +4368,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) error_msgs += new_error_msgs else: - error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + # Sharded checkpoint or whole but low_cpu_mem_usage==True + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + error_msgs += _load_state_dict_into_model( + model_to_load, state_dict, start_prefix, assign_to_params_buffers + ) # force memory release del state_dict diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index b568850060..db65f6e525 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -178,6 +178,7 @@ class EncoderDecoderModel(PreTrainedModel): base_model_prefix = "encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False def __init__( self, diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index a7f0fea8f4..b77b873183 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -773,6 +773,7 @@ class LxmertPreTrainedModel(PreTrainedModel): config_class = LxmertConfig load_tf_weights = load_tf_weights_in_lxmert base_model_prefix = "lxmert" + _supports_param_buffer_assignment = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index b6125fb4db..979bd69de9 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -159,6 +159,7 @@ class VisionEncoderDecoderModel(PreTrainedModel): base_model_prefix = "vision_encoder_decoder" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False def __init__( self, diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index a65ec043de..20d8e3911d 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -512,6 +512,12 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index 357b91a41e..da909a7c4e 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -476,6 +476,12 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5)) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + @require_torch @require_sentencepiece diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index 2b01830946..c0cf21b236 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -758,6 +758,12 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix [encoder_expected_shape] * len(attentions), ) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + @require_torch class LongT5TGlobalModelTest(LongT5ModelTest): @@ -1097,6 +1103,12 @@ class LongT5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): [self.model_tester.num_attention_heads, block_len, 3 * block_len], ) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest): def setUp(self): diff --git a/tests/models/lxmert/test_modeling_lxmert.py b/tests/models/lxmert/test_modeling_lxmert.py index b019d3ed16..1ff8c00261 100644 --- a/tests/models/lxmert/test_modeling_lxmert.py +++ b/tests/models/lxmert/test_modeling_lxmert.py @@ -778,6 +778,12 @@ class LxmertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def test_save_load_low_cpu_mem_usage_no_safetensors(self): pass + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + @require_torch class LxmertModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 953144043f..a29a9c8a9e 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -331,6 +331,12 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix model.generate(input_ids, attention_mask=attention_mask) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def _long_tensor(tok_lst): return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 943b3fbf6f..4c0bf291c1 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -369,6 +369,12 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi 2, ) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/nllb_moe/test_modeling_nllb_moe.py b/tests/models/nllb_moe/test_modeling_nllb_moe.py index a02dbcaf7f..d8dc3b6ef3 100644 --- a/tests/models/nllb_moe/test_modeling_nllb_moe.py +++ b/tests/models/nllb_moe/test_modeling_nllb_moe.py @@ -346,6 +346,12 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1]) self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0]) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + @require_torch @require_sentencepiece diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 9c16214a1c..7a0eebd7bd 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -323,6 +323,12 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def test_sample_generate(self): pass + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py index 2647c2eac7..45796b4574 100644 --- a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -506,6 +506,12 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase): def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def test_attention_outputs(self): # expected length is subsampled so need to change a bit this test if not self.has_attentions: @@ -758,6 +764,12 @@ class SeamlessM4TModelWithTextInputTest( def test_retain_grad_hidden_states_attentions(self): pass + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + @require_torch class SeamlessM4TGenerationTest(unittest.TestCase): diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index f450dca519..c891415f19 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -522,6 +522,12 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase) def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def test_attention_outputs(self): # expected length is subsampled so need to change a bit this test if not self.has_attentions: @@ -748,6 +754,12 @@ class SeamlessM4Tv2ModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixi def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + @require_torch class SeamlessM4Tv2GenerationTest(unittest.TestCase): diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 71b852df6e..13241151a8 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -720,6 +720,12 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + class SwitchTransformersEncoderOnlyModelTester: def __init__( @@ -843,6 +849,12 @@ class SwitchTransformersEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + @unittest.skip( + reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245" + ) + def test_load_save_without_tied_weights(self): + pass + def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task]) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index ed540fd5e5..deaac17554 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -20,6 +20,7 @@ import os.path import sys import tempfile import threading +import time import unittest import unittest.mock as mock import uuid @@ -894,32 +895,42 @@ class ModelUtilsTest(TestCasePlus): @require_usr_bin_time @require_accelerate @mark.accelerate_tests - def test_from_pretrained_low_cpu_mem_usage_measured(self): - # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default + def test_from_pretrained_low_cpu_mem_usage_slower(self): + # Before this would test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default + # Now though the memory is the same, we simply test that loading with `low_cpu_mem_usage` winds up being *slower* + # (mostly from extra logic needed) - mname = "google-bert/bert-base-cased" + mname = "hf-internal-testing/tiny-random-bert" preamble = "from transformers import AutoModel" one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)' + start_time = time.time() + # Save this output as `max_rss_normal` if testing memory results max_rss_normal = self.python_one_liner_max_rss(one_liner_str) + end_time = time.time() + elapsed_time_normal = end_time - start_time # print(f"{max_rss_normal=}") one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)' + start_time = time.time() + # Save this output as `max_rss_low_mem` if testing memory results max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str) - # print(f"{max_rss_low_mem=}") + end_time = time.time() + elapsed_time_low_mem = end_time - start_time - diff_bytes = max_rss_normal - max_rss_low_mem - diff_percent = diff_bytes / max_rss_low_mem - # print(f"{diff_bytes=}, {diff_percent=}") - # ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but - # measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that - # it's at least 15% less cpu memory consumed + # Should be within 2MBs of each other (overhead) + self.assertAlmostEqual( + max_rss_normal / 1024 / 1024, + max_rss_low_mem / 1024 / 1024, + delta=2, + msg="using `low_cpu_mem_usage` should incur the same memory usage in both cases.", + ) self.assertGreater( - diff_percent, - 0.15, - "should use less CPU memory for low_cpu_mem_usage=True, " - f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}", + elapsed_time_low_mem, + elapsed_time_normal, + "using `low_cpu_mem_usage` should be slower due to extra logic, " + f"but got elapsed_time_normal={elapsed_time_normal} and elapsed_time_low_mem={elapsed_time_low_mem}", ) # if you want to compare things manually, let's first look at the size of the model in bytes