Correct missing keys + test (#3143)
This commit is contained in:
@@ -539,6 +539,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
model_to_load = getattr(model, cls.base_model_prefix)
|
model_to_load = getattr(model, cls.base_model_prefix)
|
||||||
|
|
||||||
load(model_to_load, prefix=start_prefix)
|
load(model_to_load, prefix=start_prefix)
|
||||||
|
|
||||||
|
if model.__class__.__name__ != model_to_load.__class__.__name__:
|
||||||
|
base_model_state_dict = model_to_load.state_dict().keys()
|
||||||
|
head_model_state_dict_without_base_prefix = [
|
||||||
|
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
||||||
|
]
|
||||||
|
|
||||||
|
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
||||||
|
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Weights of {} not initialized from pretrained model: {}".format(
|
"Weights of {} not initialized from pretrained model: {}".format(
|
||||||
|
|||||||
@@ -526,6 +526,21 @@ class ModelTesterMixin:
|
|||||||
x = model.get_output_embeddings()
|
x = model.get_output_embeddings()
|
||||||
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
||||||
|
|
||||||
|
def test_correct_missing_keys(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
base_model_prefix = model.base_model_prefix
|
||||||
|
|
||||||
|
if hasattr(model, base_model_prefix):
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||||
|
model.base_model.save_pretrained(temp_dir_name)
|
||||||
|
model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
|
||||||
|
|
||||||
|
with self.subTest(msg="Missing keys for {}".format(model.__class__.__name__)):
|
||||||
|
self.assertGreater(len(loading_info["missing_keys"]), 0)
|
||||||
|
|
||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
if not self.test_torchscript:
|
if not self.test_torchscript:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user