[roberta] fix lm_head.decoder.weight ignore_key handling (#12446)

* fix lm_head.decoder.weight ignore_key handling

* fix the mutable class variable

* Update src/transformers/models/roberta/modeling_roberta.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* replicate the comment

* make deterministic

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Stas Bekman
2021-07-01 10:31:19 -07:00
committed by GitHub
parent 7f0027db30
commit 2d1d92181a
4 changed files with 48 additions and 8 deletions

View File

@@ -15,9 +15,10 @@
import unittest
from copy import deepcopy
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
@@ -43,6 +44,8 @@ if is_torch_available():
create_position_ids_from_input_ids,
)
ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
class RobertaModelTester:
def __init__(
@@ -475,7 +478,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
@require_torch
class RobertaModelIntegrationTest(unittest.TestCase):
class RobertaModelIntegrationTest(TestCasePlus):
@slow
def test_inference_masked_lm(self):
model = RobertaForMaskedLM.from_pretrained("roberta-base")
@@ -527,3 +530,23 @@ class RobertaModelIntegrationTest(unittest.TestCase):
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
# XXX: this might be a candidate for common tests if we have many of those
def test_lm_head_ignore_keys(self):
keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
config = RobertaConfig.from_pretrained(ROBERTA_TINY)
config_tied = deepcopy(config)
config_tied.tie_word_embeddings = True
config_untied = deepcopy(config)
config_untied.tie_word_embeddings = False
for cls in [RobertaForMaskedLM, RobertaForCausalLM]:
model = cls(config_tied)
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
# the keys should be different when embeddings aren't tied
model = cls(config_untied)
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
# test that saving works with updated ignore keys - just testing that it doesn't fail
model.save_pretrained(self.get_auto_remove_tmp_dir())