Fix loading with only a state dict and `low_cpu_mem_usage=True` (#35217)
* fix loading with only state dict and config
* style
* add tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -4022,8 +4022,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||||
else:
|
else:
|
||||||
loaded_state_dict_keys = list(state_dict.keys())
|
loaded_state_dict_keys = list(state_dict.keys())
|
||||||
|
if (
|
||||||
if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
|
gguf_path is None
|
||||||
|
and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()))
|
||||||
|
and pretrained_model_name_or_path is not None
|
||||||
|
):
|
||||||
# In case some weights need to be kept in float32 and accelerate is not installed,
|
# In case some weights need to be kept in float32 and accelerate is not installed,
|
||||||
# we later on want to take the path where state_dict is not None, that is the one
|
# we later on want to take the path where state_dict is not None, that is the one
|
||||||
# that do not require accelerate.
|
# that do not require accelerate.
|
||||||
@@ -4679,7 +4682,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
|
|
||||||
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
||||||
if gguf_path:
|
if gguf_path or low_cpu_mem_usage:
|
||||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
|
|||||||
@@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
new_model.generate(random_ids, max_new_tokens=3)
|
new_model.generate(random_ids, max_new_tokens=3)
|
||||||
self.assertTrue(len(w) == 0)
|
self.assertTrue(len(w) == 0)
|
||||||
|
|
||||||
|
def test_load_model_with_state_dict_only(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
config = model.config
|
||||||
|
|
||||||
|
model_loaded = BertModel.from_pretrained(
|
||||||
|
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
||||||
|
)
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
config = model.config
|
||||||
|
|
||||||
|
model_loaded = BertModel.from_pretrained(
|
||||||
|
pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True
|
||||||
|
)
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user