From 581cf96e0c038b34329f8802cd5b04b66fc87d18 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 28 Mar 2025 16:36:44 +0100 Subject: [PATCH] fix tied weigths issue (#37031) * fix * comment --------- Co-authored-by: ydshieh --- src/transformers/modeling_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 137bd01c01..ac84297280 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -784,6 +784,9 @@ def _load_state_dict_into_meta_model( if is_meta_state_dict: file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) + # Used to fix the issue mentioned in #37031: when loading a model with tied weights in state_dict + `tie_word_embeddings = False`, + # we need to make sure they are not loaded as tied weights! + data_ptrs = set() for param_name, empty_param in state_dict.items(): if param_name not in expected_keys: continue @@ -853,11 +856,19 @@ def _load_state_dict_into_meta_model( if is_fsdp_enabled(): param_device = "cpu" if is_local_dist_rank_0() else "meta" module, param_type = get_module_from_name(model, param_name) + + # avoid tied weights + if param.data_ptr() in data_ptrs: + param = param.clone() + module.load_state_dict( {param_type: param.to(param_device)}, strict=False, assign=True, ) + + # Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights + data_ptrs.add(model.state_dict()[param_name].data_ptr()) else: hf_quantizer.create_quantized_param( model, param, param_name, param_device, state_dict, unexpected_keys