fix tied weights issue (#37031)
* fix

* comment

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -784,6 +784,9 @@ def _load_state_dict_into_meta_model(
|
|||||||
if is_meta_state_dict:
|
if is_meta_state_dict:
|
||||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
|
|
||||||
|
# Used to fix the issue mentioned in #37031: when loading a model with tied weights in state_dict + `tie_word_embeddings = False`,
|
||||||
|
# we need to make sure they are not loaded as tied weights!
|
||||||
|
data_ptrs = set()
|
||||||
for param_name, empty_param in state_dict.items():
|
for param_name, empty_param in state_dict.items():
|
||||||
if param_name not in expected_keys:
|
if param_name not in expected_keys:
|
||||||
continue
|
continue
|
||||||
@@ -853,11 +856,19 @@ def _load_state_dict_into_meta_model(
|
|||||||
if is_fsdp_enabled():
|
if is_fsdp_enabled():
|
||||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||||
module, param_type = get_module_from_name(model, param_name)
|
module, param_type = get_module_from_name(model, param_name)
|
||||||
|
|
||||||
|
# avoid tied weights
|
||||||
|
if param.data_ptr() in data_ptrs:
|
||||||
|
param = param.clone()
|
||||||
|
|
||||||
module.load_state_dict(
|
module.load_state_dict(
|
||||||
{param_type: param.to(param_device)},
|
{param_type: param.to(param_device)},
|
||||||
strict=False,
|
strict=False,
|
||||||
assign=True,
|
assign=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights
|
||||||
|
data_ptrs.add(model.state_dict()[param_name].data_ptr())
|
||||||
else:
|
else:
|
||||||
hf_quantizer.create_quantized_param(
|
hf_quantizer.create_quantized_param(
|
||||||
model, param, param_name, param_device, state_dict, unexpected_keys
|
model, param, param_name, param_device, state_dict, unexpected_keys
|
||||||
|
|||||||
Reference in New Issue
Block a user