Fix tied weight loading with TP and loading sub state_dicts (#37758)
Update modeling_utils.py
This commit is contained in:
@@ -4978,7 +4978,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
name: param for name, param in model.named_parameters() if not name.startswith(prefix)
|
name: param for name, param in model.named_parameters() if not name.startswith(prefix)
|
||||||
}
|
}
|
||||||
for name, param in parameters_to_initialize.items():
|
for name, param in parameters_to_initialize.items():
|
||||||
# First move data to correct
|
# If it is still on meta here, it means that it's a tied weight that will be tied later anyway -> skip it
|
||||||
|
if param.device.type == "meta":
|
||||||
|
continue
|
||||||
|
# Shard the param
|
||||||
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
|
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
|
||||||
shard_and_distribute_module(
|
shard_and_distribute_module(
|
||||||
model,
|
model,
|
||||||
|
|||||||
Reference in New Issue
Block a user