From 0af0a5f9698f501240464613d52677f206cdc291 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 24 Apr 2025 16:47:40 +0200 Subject: [PATCH] Fix tied weight loading with TP and loading sub state_dicts (#37758) Update modeling_utils.py --- src/transformers/modeling_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e73862b54e..168d421c99 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4978,7 +4978,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi name: param for name, param in model.named_parameters() if not name.startswith(prefix) } for name, param in parameters_to_initialize.items(): - # First move data to correct + # If it is still on meta here, it means that it's a tied weight that will be tied later anyway -> skip it + if param.device.type == "meta": + continue + # Shard the param to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex) shard_and_distribute_module( model,