Fix loading models with mismatched sizes (#36463)
* Fix loading model with mismatched sizes
* trigger tests
This commit is contained in:
committed by
GitHub
parent
222505c7e4
commit
02776d2c6a
@@ -4907,7 +4907,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model_to_load, state_dict, start_prefix
|
model_to_load, state_dict, start_prefix
|
||||||
)
|
)
|
||||||
# at this point the state dict should be on cpu, we don't need to actually read it
|
# at this point the state dict should be on cpu, we don't need to actually read it
|
||||||
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict)
|
mismatched_names = [name for name, _, _ in mismatched_keys]
|
||||||
|
fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names}
|
||||||
|
fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict)
|
||||||
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers)
|
||||||
else:
|
else:
|
||||||
# This should always be a list but, just to be sure.
|
# This should always be a list but, just to be sure.
|
||||||
|
|||||||
Reference in New Issue
Block a user