[Core] [Offloading] Fix saving offloaded submodules (#39280)
* fix counting meta tensors, fix onloading meta tensors Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove unrelated fix Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove unrelated change Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add clarifying comment Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add test_save_offloaded_model_with_direct_params Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix merge conflict, add decorators Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -3900,12 +3900,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# init state_dict for this shard
|
||||
shard_state_dict = dict.fromkeys(shard, "")
|
||||
for module_name in shard:
|
||||
# skip to collect this weight again
|
||||
if shard_state_dict.get(module_name) != "":
|
||||
continue
|
||||
module = module_map[module_name]
|
||||
# update state dict with onloaded parameters
|
||||
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
|
||||
# note that get_state_dict_from_offload can update with meta tensors
|
||||
# if both a parent module and its descendant are offloaded
|
||||
tensor = shard_state_dict[module_name]
|
||||
if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
|
||||
# update state dict with onloaded parameters
|
||||
module = module_map[module_name]
|
||||
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
|
||||
|
||||
# assign shard to be the completed state dict
|
||||
shard = shard_state_dict
|
||||
|
||||
Reference in New Issue
Block a user