Fix tied weight loading with TP and loading sub state_dicts (#37758)
Update modeling_utils.py
This commit is contained in:
@@ -4978,7 +4978,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
name: param for name, param in model.named_parameters() if not name.startswith(prefix)
|
||||
}
|
||||
for name, param in parameters_to_initialize.items():
|
||||
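# First move data to the correct device (NOTE(review): original comment was cut off after "correct" -- confirm intended wording)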
|
||||
# If the param is still on the meta device at this point, it is a tied weight that will be tied later anyway -> skip it
|
||||
if param.device.type == "meta":
|
||||
continue
|
||||
# Shard the param
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
|
||||
shard_and_distribute_module(
|
||||
model,
|
||||
|
||||
Reference in New Issue
Block a user