Correct missing keys + test (#3143)
This commit is contained in:
@@ -539,6 +539,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
model_to_load = getattr(model, cls.base_model_prefix)
|
||||
|
||||
load(model_to_load, prefix=start_prefix)
|
||||
|
||||
if model.__class__.__name__ != model_to_load.__class__.__name__:
|
||||
base_model_state_dict = model_to_load.state_dict().keys()
|
||||
head_model_state_dict_without_base_prefix = [
|
||||
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
||||
]
|
||||
|
||||
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
logger.info(
|
||||
"Weights of {} not initialized from pretrained model: {}".format(
|
||||
|
||||
Reference in New Issue
Block a user