Correct missing keys + test (#3143)

This commit is contained in:
Lysandre Debut
2020-03-05 17:01:54 -05:00
committed by GitHub
parent 1741d740f2
commit 0001d05686
2 changed files with 24 additions and 0 deletions

View File

@@ -539,6 +539,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
if model.__class__.__name__ != model_to_load.__class__.__name__:
base_model_state_dict = model_to_load.state_dict().keys()
head_model_state_dict_without_base_prefix = [
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
]
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(