diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 22dd1b7cce..2ea88fb9b0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4022,8 +4022,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_state_dict_keys = list(state_dict.keys()) - - if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())): + if ( + gguf_path is None + and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) + and pretrained_model_name_or_path is not None + ): # In case some weights need to be kept in float32 and accelerate is not installed, # we later on want to take the path where state_dict is not None, that is the one # that do not require accelerate. @@ -4679,7 +4682,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) # For GGUF models `state_dict` is never set to None as the state dict is always small - if gguf_path: + if gguf_path or low_cpu_mem_usage: fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict) error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 458ddeee5f..31c0d01af7 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus): new_model.generate(random_ids, max_new_tokens=3) self.assertTrue(len(w) == 0) + def test_load_model_with_state_dict_only(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + state_dict = model.state_dict() + config = model.config + + model_loaded = BertModel.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict + ) + self.assertTrue(check_models_equal(model, model_loaded)) + + def test_load_model_with_state_dict_only_low_cpu_mem_usage(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + state_dict = model.state_dict() + config = model.config + + model_loaded = BertModel.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True + ) + self.assertTrue(check_models_equal(model, model_loaded)) + @slow @require_torch