Fix tied weight loading with TP and loading sub state_dicts (#37758)

Update modeling_utils.py
This commit is contained in:
Cyril Vallez
2025-04-24 16:47:40 +02:00
committed by GitHub
parent 3af24f7e27
commit 0af0a5f969

View File

@@ -4978,7 +4978,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
name: param for name, param in model.named_parameters() if not name.startswith(prefix)
}
for name, param in parameters_to_initialize.items():
# First move data to correct
# If it is still on meta here, it means that it's a tied weight that will be tied later anyway -> skip it
if param.device.type == "meta":
continue
# Shard the param
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
shard_and_distribute_module(
model,