[save_pretrained ] Skip collecting duplicated weight (#36409)
* Skip collecting duplicated weight * format
This commit is contained in:
@@ -3071,6 +3071,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# init state_dict for this shard
|
# init state_dict for this shard
|
||||||
shard_state_dict = {name: "" for name in shard}
|
shard_state_dict = {name: "" for name in shard}
|
||||||
for module_name in shard:
|
for module_name in shard:
|
||||||
|
# skip to collect this weight again
|
||||||
|
if shard_state_dict.get(module_name) != "":
|
||||||
|
continue
|
||||||
module = module_map[module_name]
|
module = module_map[module_name]
|
||||||
# update state dict with onloaded parameters
|
# update state dict with onloaded parameters
|
||||||
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
|
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user