Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
This commit is contained in:
Sylvain Gugger
2023-06-27 14:45:40 -04:00
committed by GitHub
parent 53194991e9
commit 8e5d1619b3
138 changed files with 320 additions and 1140 deletions

View File

@@ -15,7 +15,6 @@
import unittest
from copy import deepcopy
from transformers import RobertaConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
@@ -579,23 +578,3 @@ class RobertaModelIntegrationTest(TestCasePlus):
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
# XXX: this might be a candidate for common tests if we have many of those
def test_lm_head_ignore_keys(self):
    """Check that the RoBERTa LM heads' save-time ignore list follows `tie_word_embeddings`.

    Two copies of the tiny RoBERTa config are built, one with tied word
    embeddings and one without. For both `RobertaForMaskedLM` and
    `RobertaForCausalLM` the test compares `_keys_to_ignore_on_save` against
    the list expected for each copy, then saves the untied model to a
    temporary directory to make sure saving does not raise.

    NOTE(review): indentation is not preserved in the text this was taken
    from; the loop nesting below (including the final save inside the loop)
    is reconstructed from the statements' data flow -- confirm.
    """
    base_config = RobertaConfig.from_pretrained(ROBERTA_TINY)

    tied_config = deepcopy(base_config)
    tied_config.tie_word_embeddings = True

    untied_config = deepcopy(base_config)
    untied_config.tie_word_embeddings = False

    expected_when_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    expected_when_untied = [r"lm_head.decoder.bias"]

    for head_class in (RobertaForMaskedLM, RobertaForCausalLM):
        model = head_class(tied_config)
        self.assertEqual(model._keys_to_ignore_on_save, expected_when_tied, head_class)

        # Without tying, the ignore list is expected to be different.
        model = head_class(untied_config)
        self.assertEqual(model._keys_to_ignore_on_save, expected_when_untied, head_class)

        # Smoke test only: saving with the updated ignore keys must not fail.
        save_dir = self.get_auto_remove_tmp_dir()
        model.save_pretrained(save_dir)

View File

@@ -1562,7 +1562,7 @@ class ModelTesterMixin:
@require_safetensors
def test_can_use_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
@@ -1579,6 +1579,8 @@ class ModelTesterMixin:
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking there was no complaint of missing weights
self.assertEqual(infos["missing_keys"], [])
# Checking the tensor sharing is correct
ptrs = defaultdict(list)
@@ -1595,6 +1597,25 @@ class ModelTesterMixin:
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)
def test_load_save_without_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = False
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as d:
model.save_pretrained(d)
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# Checking the state dicts are correct
reloaded_state = model_reloaded.state_dict()
for k, v in model.state_dict().items():
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking there was no complain of missing weights
self.assertEqual(infos["missing_keys"], [])
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = True
@@ -1620,55 +1641,72 @@ class ModelTesterMixin:
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(tied_params, [])
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
def test_tied_model_weights_key_ignore(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def test_model_weights_reload_no_missing_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
model_tied.save_pretrained(d)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
# We are nuking ALL weights on file, so every parameter should
# yell on load. We're going to detect if we yell too much, or too little.
with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
torch.save({}, f)
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# ! Actually we could use `state_dict()` and check iteratively the tensors which are the same (for instance using `tensor.data_ptr()`). to detect the duplicates.
# ```python
# model = GPT2LMHeadModel.from_pretrained("gpt2")
# "lm_head.weight" in model.state_dict().keys() # True
# "lm_head.weight" in model.named_parameters() # False
# In [6]: model.lm_head.weight.data_ptr()
# Out[6]: 139901378371648
# In [9]: model.transformer.wte.weight.data_ptr()
# Out[9]: 139901378371648 # Same PTR, it's the same DATA ! we would need to check for stride too to be 100% accurate.
# ```
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
prefix = f"{model_reloaded.base_model_prefix}."
params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers()))
# param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
missing_keys = set(infos["missing_keys"])
extra_missing = missing_keys - param_names
# missed_missing = param_names - missing_keys
# Remove tied weights from extra missing: they are normally not warned as missing if their tied
# counterpart is present but here there are no weights at all so we do get the warning.
ptrs = collections.defaultdict(list)
for name, tensor in model_reloaded.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params:
group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
# We remove the group from extra_missing if not all weights from group are in it
if len(group - extra_missing) > 0:
extra_missing = extra_missing - set(group)
self.assertEqual(
extra_missing,
set(),
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}",
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}. "
f"For debugging, tied parameters are {tied_params}",
)
# self.assertEqual(
# missed_missing,
# set(),
# f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
# " parameters",
# )
missed_missing = param_names - missing_keys
# Remove nonpersistent buffers from missed_missing
buffers = [n for n, _ in model_reloaded.named_buffers()]
nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
nonpersistent_buffers = {
k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
}
missed_missing = missed_missing - nonpersistent_buffers
if model_reloaded._keys_to_ignore_on_load_missing is None:
expected_missing = set()
else:
expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
self.assertEqual(
missed_missing,
expected_missing,
f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
" parameters. If they are non persistent buffers make sure to instantiate them with"
" `persistent=False`",
)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -500,8 +500,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(os.path.isfile(weights_index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
for i in range(1, 6):
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
for i in range(1, 5):
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["bin"])
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
@@ -546,8 +546,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(os.path.isfile(weights_index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
for i in range(1, 6):
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
for i in range(1, 5):
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["safetensors"])
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))