Fix loading models with mismatched sizes (#36463)
* Fix loading model with mismatched sizes
* trigger tests
This commit is contained in:
Committed by: GitHub
Parent: 222505c7e4
Commit: 02776d2c6a
@@ -4907,7 +4907,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
  model_to_load, state_dict, start_prefix
  )
  # at this point the state dict should be on cpu, we don't need to actually read it
- fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
+ mismatched_names = [name for name, _, _ in mismatched_keys]
+ fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
+ fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
  model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
  else:
  # This should always be a list but, just to be sure.
Reference in New Issue
Block a user