Fix loading models with mismatched sizes (#36463)

* Fix loading models with mismatched sizes

* trigger tests
This commit is contained in:
Pavel Iakubovskii
2025-02-28 10:48:59 +00:00
committed by GitHub
parent 222505c7e4
commit 02776d2c6a

View File

@@ -4907,7 +4907,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_to_load, state_dict, start_prefix
)
# at this point the state dict should be on cpu, we don't need to actually read it
- fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
+ mismatched_names = [name for name, _, _ in mismatched_keys]
+ fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
+ fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
else:
# This should always be a list but, just to be sure.