Clean load keys (#24505)
* Preliminary work on some models
* Fix test load missing and make sure nonpersistent buffers are tested
* Always ignore nonpersistent buffers if in state_dict
* Treat models
* More models
* Treat remaining models
* Fix quality
* Fix tests
* Remove draft
* This test is not needed anymore
* Fix copies
* Fix last test
* Newly added models
* Fix last tests
* Address review comments
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
|
||||
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from transformers import RobertaConfig, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
|
||||
@@ -579,23 +578,3 @@ class RobertaModelIntegrationTest(TestCasePlus):
|
||||
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
|
||||
|
||||
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
|
||||
|
||||
# XXX: this might be a candidate for common tests if we have many of those
|
||||
def test_lm_head_ignore_keys(self):
    """Check `_keys_to_ignore_on_save` with tied and untied word embeddings.

    For each LM-head model class, the ignore list is compared against the
    expected patterns for both settings of `tie_word_embeddings`, and the
    untied model is then saved to a temporary directory to make sure saving
    does not fail.
    """
    base_config = RobertaConfig.from_pretrained(ROBERTA_TINY)

    # Two independent copies that differ only in whether embeddings are tied.
    tied_config = deepcopy(base_config)
    tied_config.tie_word_embeddings = True
    untied_config = deepcopy(base_config)
    untied_config.tie_word_embeddings = False

    expected_when_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    expected_when_untied = [r"lm_head.decoder.bias"]

    for model_cls in (RobertaForMaskedLM, RobertaForCausalLM):
        tied_model = model_cls(tied_config)
        self.assertEqual(tied_model._keys_to_ignore_on_save, expected_when_tied, model_cls)

        # Untying the embeddings must change the set of ignored keys.
        untied_model = model_cls(untied_config)
        self.assertEqual(untied_model._keys_to_ignore_on_save, expected_when_untied, model_cls)

        # Smoke test only: saving with the updated ignore keys must not raise.
        # NOTE(review): indentation was lost in this view; the save call is
        # assumed to sit inside the loop - confirm against the original file.
        untied_model.save_pretrained(self.get_auto_remove_tmp_dir())
|
||||
|
||||
@@ -1562,7 +1562,7 @@ class ModelTesterMixin:
|
||||
|
||||
@require_safetensors
|
||||
def test_can_use_safetensors(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
@@ -1579,6 +1579,8 @@ class ModelTesterMixin:
|
||||
torch.testing.assert_close(
|
||||
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
|
||||
)
|
||||
# Checking there was no complaint about missing weights
|
||||
self.assertEqual(infos["missing_keys"], [])
|
||||
|
||||
# Checking the tensor sharing is correct
|
||||
ptrs = defaultdict(list)
|
||||
@@ -1595,6 +1597,25 @@ class ModelTesterMixin:
|
||||
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
|
||||
)
|
||||
|
||||
def test_load_save_without_tied_weights(self):
    """Round-trip every model class through save/load with untied embeddings.

    With `tie_word_embeddings` disabled, each model is saved to a temporary
    directory and reloaded. Every entry of the original state dict must be
    present in the reloaded one with matching values, and the loading info
    must report no missing keys.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    config.tie_word_embeddings = False

    for model_class in self.all_model_classes:
        source_model = model_class(config)
        with tempfile.TemporaryDirectory() as save_dir:
            source_model.save_pretrained(save_dir)
            reloaded_model, loading_info = model_class.from_pretrained(save_dir, output_loading_info=True)

            # Compare the state dicts entry by entry.
            reloaded_tensors = reloaded_model.state_dict()
            for k, expected in source_model.state_dict().items():
                self.assertIn(k, reloaded_tensors, f"Key {k} is missing from reloaded")
                torch.testing.assert_close(
                    expected,
                    reloaded_tensors[k],
                    msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}",
                )

            # Loading must not have reported any missing weights.
            self.assertEqual(loading_info["missing_keys"], [])
|
||||
|
||||
def test_tied_weights_keys(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.tie_word_embeddings = True
|
||||
@@ -1620,55 +1641,72 @@ class ModelTesterMixin:
|
||||
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
||||
|
||||
tied_params = [group for group in tied_params if len(group) > 1]
|
||||
self.assertListEqual(tied_params, [])
|
||||
self.assertListEqual(
|
||||
tied_params,
|
||||
[],
|
||||
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
||||
)
|
||||
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
def test_model_weights_reload_no_missing_tied_weights(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
model_tied.save_pretrained(d)
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
# We are nuking ALL weights on file, so every parameter should
|
||||
# yell on load. We're going to detect if we yell too much, or too little.
|
||||
with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
|
||||
with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
|
||||
torch.save({}, f)
|
||||
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
|
||||
|
||||
# ! Actually we could use `state_dict()` and check iteratively the tensors which are the same (for instance using `tensor.data_ptr()`). to detect the duplicates.
|
||||
# ```python
|
||||
# model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
# "lm_head.weight" in model.state_dict().keys() # True
|
||||
# "lm_head.weight" in model.named_parameters() # False
|
||||
# In [6]: model.lm_head.weight.data_ptr()
|
||||
# Out[6]: 139901378371648
|
||||
# In [9]: model.transformer.wte.weight.data_ptr()
|
||||
# Out[9]: 139901378371648 # Same PTR, it's the same DATA ! we would need to check for stride too to be 100% accurate.
|
||||
# ```
|
||||
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
|
||||
|
||||
prefix = f"{model_reloaded.base_model_prefix}."
|
||||
params = dict(model_reloaded.named_parameters())
|
||||
params.update(dict(model_reloaded.named_buffers()))
|
||||
# param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
|
||||
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
|
||||
|
||||
missing_keys = set(infos["missing_keys"])
|
||||
|
||||
extra_missing = missing_keys - param_names
|
||||
# missed_missing = param_names - missing_keys
|
||||
# Remove tied weights from extra missing: they are normally not warned as missing if their tied
|
||||
# counterpart is present but here there are no weights at all so we do get the warning.
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_reloaded.state_dict().items():
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
for group in tied_params:
|
||||
group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
|
||||
# We remove the group from extra_missing if not all weights from group are in it
|
||||
if len(group - extra_missing) > 0:
|
||||
extra_missing = extra_missing - set(group)
|
||||
|
||||
self.assertEqual(
|
||||
extra_missing,
|
||||
set(),
|
||||
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}",
|
||||
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}. "
|
||||
f"For debugging, tied parameters are {tied_params}",
|
||||
)
|
||||
|
||||
# self.assertEqual(
|
||||
# missed_missing,
|
||||
# set(),
|
||||
# f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
|
||||
# " parameters",
|
||||
# )
|
||||
missed_missing = param_names - missing_keys
|
||||
# Remove nonpersistent buffers from missed_missing
|
||||
buffers = [n for n, _ in model_reloaded.named_buffers()]
|
||||
nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
|
||||
nonpersistent_buffers = {
|
||||
k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
|
||||
}
|
||||
missed_missing = missed_missing - nonpersistent_buffers
|
||||
|
||||
if model_reloaded._keys_to_ignore_on_load_missing is None:
|
||||
expected_missing = set()
|
||||
else:
|
||||
expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
|
||||
self.assertEqual(
|
||||
missed_missing,
|
||||
expected_missing,
|
||||
f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
|
||||
" parameters. If they are non persistent buffers make sure to instantiate them with"
|
||||
" `persistent=False`",
|
||||
)
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -500,8 +500,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue(os.path.isfile(weights_index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
|
||||
|
||||
for i in range(1, 6):
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
|
||||
for i in range(1, 5):
|
||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["bin"])
|
||||
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||
self.assertTrue(os.path.isfile(weights_name_file))
|
||||
|
||||
@@ -546,8 +546,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue(os.path.isfile(weights_index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
for i in range(1, 6):
|
||||
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
|
||||
for i in range(1, 5):
|
||||
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["safetensors"])
|
||||
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||
self.assertTrue(os.path.isfile(weights_name_file))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user