From 1eee1cedfdc854258564c3f301e42bc6fe982e80 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:54:32 +0100 Subject: [PATCH] Fix loading with only state dict and low_cpu_mem_usage = True (#35217) * fix loading with only state dict and config * style * add tests --------- Co-authored-by: Sayak Paul --- src/transformers/modeling_utils.py | 9 ++++++--- tests/utils/test_modeling_utils.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 22dd1b7cce..2ea88fb9b0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4022,8 +4022,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - - if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())): + if ( + gguf_path is None + and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) + and pretrained_model_name_or_path is not None + ): # In case some weights need to be kept in float32 and accelerate is not installed, # we later on want to take the path where state_dict is not None, that is the one # that do not require accelerate. @@ -4679,7 +4682,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) # For GGUF models `state_dict` is never set to None as the state dict is always small - if gguf_path: + if gguf_path or low_cpu_mem_usage: fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 458ddeee5f..31c0d01af7 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus): new_model.generate(random_ids, max_new_tokens=3) self.assertTrue(len(w) == 0) + def test_load_model_with_state_dict_only(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + state_dict = model.state_dict() + config = model.config + + model_loaded = BertModel.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict + ) + self.assertTrue(check_models_equal(model, model_loaded)) + + def test_load_model_with_state_dict_only_low_cpu_mem_usage(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + state_dict = model.state_dict() + config = model.config + + model_loaded = BertModel.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True + ) + self.assertTrue(check_models_equal(model, model_loaded)) + @slow @require_torch