From 17792556b21b4da0dbb9e4b59b39fb34aae4047c Mon Sep 17 00:00:00 2001 From: wejoncy Date: Thu, 27 Feb 2025 17:57:11 +0800 Subject: [PATCH] [save_pretrained ] Skip collecting duplicated weight (#36409) * Skip collecting duplicated weight * format --- src/transformers/modeling_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index db318156ac..1553287c92 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3071,6 +3071,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # init state_dict for this shard shard_state_dict = {name: "" for name in shard} for module_name in shard: + # skip this weight if it has already been collected + if shard_state_dict.get(module_name) != "": + continue module = module_map[module_name] # update state dict with onloaded parameters shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)