Attempting to test automatically the _keys_to_ignore. (#20042)
* Attempting to test automatically the `_keys_to_ignore`.
* Style.
* First fix pass.
* Moving test on its own.
* Another batch.
* Second round removing BatchNorm
* Fixing layoutlmv{2,3} + support older Python.
* Disable the missing-keys warning.
* Removing dodgy additions.
* Big pass.
* mbart.
* More corrections.
* Fixup.
* Updating test_correct_missing_keys
* Add an escape hatch for when the head has no extra params, so it doesn't need
the missing keys check.
* Fixing test.
* Greener.
* Green ! (except for weird splinter bug).
* Adding a test about `named_parameters` usage.
* Shorten message.
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* After rebase modifications.
* More explicit condition checking.
* Fixing slow tests issues.
* Remove extra pdb.
* Remove print.
* Attempt to make failure consistent + fixing roc_bert.
* Removing the seed (all tests passing with it).
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1468,11 +1468,24 @@ class ModelTesterMixin:
|
||||
base_model_prefix = model.base_model_prefix
|
||||
|
||||
if hasattr(model, base_model_prefix):
|
||||
|
||||
extra_params = {k: v for k, v in model.named_parameters() if not k.startswith(base_model_prefix)}
|
||||
extra_params.update({k: v for k, v in model.named_buffers() if not k.startswith(base_model_prefix)})
|
||||
# Some models define this as None
|
||||
if model._keys_to_ignore_on_load_missing:
|
||||
for key in model._keys_to_ignore_on_load_missing:
|
||||
extra_params.pop(key, None)
|
||||
|
||||
if not extra_params:
|
||||
# In that case, we *are* on a head model, but every
|
||||
# single key is not actual parameters and this is
|
||||
# tested in `test_tied_model_weights_key_ignore` test.
|
||||
continue
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.base_model.save_pretrained(temp_dir_name)
|
||||
model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
|
||||
with self.subTest(msg=f"Missing keys for {model.__class__.__name__}"):
|
||||
self.assertGreater(len(loading_info["missing_keys"]), 0)
|
||||
self.assertGreater(len(loading_info["missing_keys"]), 0, model.__class__.__name__)
|
||||
|
||||
def test_tie_model_weights(self):
|
||||
if not self.test_torchscript:
|
||||
@@ -1522,6 +1535,54 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
def test_tied_model_weights_key_ignore(self):
    """Check that no unexpected names show up as missing when loading an empty checkpoint.

    For every class in ``self.all_model_classes`` the model is instantiated
    from the common test config, saved to a temporary directory, and its
    ``pytorch_model.bin`` is then overwritten with an empty state dict. The
    model is loaded back with ``output_loading_info=True`` and every entry of
    ``missing_keys`` must be a name yielded by ``named_parameters()`` or
    ``named_buffers()`` (with the leading ``"<base_model_prefix>."`` removed
    when present). Any other entry fails the test with a message suggesting
    that a ``keys_to_ignore`` entry is missing.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        fresh_model = model_class(config)

        with tempfile.TemporaryDirectory() as save_dir:
            fresh_model.save_pretrained(save_dir)

            # Wipe ALL weights from the saved file so that every parameter is
            # reported on load; we then look for names reported in excess.
            weights_path = os.path.join(save_dir, "pytorch_model.bin")
            with open(weights_path, "wb") as handle:
                torch.save({}, handle)

            model_reloaded, infos = model_class.from_pretrained(save_dir, output_loading_info=True)

            # Author's note: an alternative would be to walk `state_dict()` and
            # detect duplicated tensors by comparing `tensor.data_ptr()` (and
            # strides, to be fully accurate). The motivating example was GPT-2,
            # where "lm_head.weight" is a key of `state_dict()` but is not
            # reported by `named_parameters()`, because `lm_head.weight` and
            # `transformer.wte.weight` share the same storage.
            prefix = f"{model_reloaded.base_model_prefix}."

            tensor_names = [name for name, _ in model_reloaded.named_parameters()]
            tensor_names += [name for name, _ in model_reloaded.named_buffers()]
            param_names = {
                name[len(prefix) :] if name.startswith(prefix) else name for name in tensor_names
            }

            missing_keys = set(infos["missing_keys"])

            # Names reported missing that are neither a parameter nor a buffer.
            extra_missing = missing_keys.difference(param_names)

            self.assertEqual(
                extra_missing,
                set(),
                f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}",
            )

            # NOTE: the reverse check (`param_names - missing_keys` must be
            # empty, i.e. no real parameter is silently ignored) is
            # intentionally not enforced here.
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user