Clean load keys (#24505)
* Preliminary work on some models * Fix test load missing and make sure nonpersistent buffers are tested * Always ignore nonpersistent buffers if in state_dict * Treat models * More models * Treat remaining models * Fix quality * Fix tests * Remove draft * This test is not needed anymore * Fix copies * Fix last test * Newly added models * Fix last tests * Address review comments
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
|
||||
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from transformers import RobertaConfig, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
|
||||
@@ -579,23 +578,3 @@ class RobertaModelIntegrationTest(TestCasePlus):
|
||||
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
|
||||
|
||||
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
|
||||
|
||||
# XXX: this might be a candidate for common tests if we have many of those
|
||||
def test_lm_head_ignore_keys(self):
|
||||
keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||
keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
|
||||
config = RobertaConfig.from_pretrained(ROBERTA_TINY)
|
||||
config_tied = deepcopy(config)
|
||||
config_tied.tie_word_embeddings = True
|
||||
config_untied = deepcopy(config)
|
||||
config_untied.tie_word_embeddings = False
|
||||
for cls in [RobertaForMaskedLM, RobertaForCausalLM]:
|
||||
model = cls(config_tied)
|
||||
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
|
||||
|
||||
# the keys should be different when embeddings aren't tied
|
||||
model = cls(config_untied)
|
||||
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
|
||||
|
||||
# test that saving works with updated ignore keys - just testing that it doesn't fail
|
||||
model.save_pretrained(self.get_auto_remove_tmp_dir())
|
||||
|
||||
Reference in New Issue
Block a user