[Core] [Offloading] Fix saving offloaded submodules (#39280)

* fix counting meta tensors, fix onloading meta tensors

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove unrelated fix

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove unrelated change

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add clarifying comment

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add test_save_offloaded_model_with_direct_params

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* fix merge conflict, add decorators

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2025-07-16 04:44:40 -04:00
committed by GitHub
parent add43c4d09
commit 31d81943c9
2 changed files with 52 additions and 6 deletions

View File

@@ -3900,12 +3900,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# init state_dict for this shard
shard_state_dict = dict.fromkeys(shard, "")
for module_name in shard:
# skip to collect this weight again
if shard_state_dict.get(module_name) != "":
continue
module = module_map[module_name]
# update state dict with onloaded parameters
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
# note that get_state_dict_from_offload can update with meta tensors
# if both a parent module and its descendant are offloaded
tensor = shard_state_dict[module_name]
if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
# update state dict with onloaded parameters
module = module_map[module_name]
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
# assign shard to be the completed state dict
shard = shard_state_dict