[save_pretrained] Skip collecting duplicated weight (#36409)

* Skip collecting duplicated weight

* format
This commit is contained in:
wejoncy
2025-02-27 17:57:11 +08:00
committed by GitHub
parent 2d6cc0dfde
commit 17792556b2

View File

@@ -3071,6 +3071,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# init state_dict for this shard # init state_dict for this shard
shard_state_dict = {name: "" for name in shard} shard_state_dict = {name: "" for name in shard}
for module_name in shard: for module_name in shard:
# skip this weight if it has already been collected
if shard_state_dict.get(module_name) != "":
continue
module = module_map[module_name] module = module_map[module_name]
# update state dict with onloaded parameters # update state dict with onloaded parameters
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict) shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)