From 2d1d92181a0f739b6817a74401c51862d28bb409 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 1 Jul 2021 10:31:19 -0700 Subject: [PATCH] [roberta] fix lm_head.decoder.weight ignore_key handling (#12446) * fix lm_head.decoder.weight ignore_key handling * fix the mutable class variable * Update src/transformers/models/roberta/modeling_roberta.py Co-authored-by: Lysandre Debut * replicate the comment * make deterministic Co-authored-by: Lysandre Debut --- src/transformers/modeling_utils.py | 2 +- .../models/roberta/modeling_roberta.py | 21 +++++++++++++-- tests/test_modeling_common.py | 6 ++--- tests/test_modeling_roberta.py | 27 +++++++++++++++++-- 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index feb1bdad07..a6529cb594 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -445,7 +445,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # (and avoid unnecessary warnings). _keys_to_ignore_on_load_unexpected = None # a list of of tensor names to ignore when saving the model (useful for keys that aren't - # trained, but which are deterministic) + # trained, but which are deterministic, or tied variables) _keys_to_ignore_on_save = None is_parallelizable = False diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 787ae588ed..b8228fa6f5 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -603,6 +603,15 @@ class RobertaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) + def update_keys_to_ignore(self, config, del_keys_to_ignore): + """Remove some keys from ignore list""" + if not config.tie_word_embeddings: + # must make a new list, or the class variable gets modified! + self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore] + self._keys_to_ignore_on_load_missing = [ + k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore + ] + ROBERTA_START_DOCSTRING = r""" @@ -864,7 +873,8 @@ class RobertaModel(RobertaPreTrainedModel): """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING ) class RobertaForCausalLM(RobertaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"] + _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -876,6 +886,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel): self.roberta = RobertaModel(config, add_pooling_layer=False) self.lm_head = RobertaLMHead(config) + # The LM head weights require special treatment only when they are tied with the word embeddings + self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) + self.init_weights() def get_output_embeddings(self): @@ -1010,7 +1023,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel): @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) class RobertaForMaskedLM(RobertaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"] + _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1025,6 +1039,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel): self.roberta = RobertaModel(config, add_pooling_layer=False) self.lm_head = RobertaLMHead(config) + # The LM head weights require special treatment only when they are tied with the word embeddings + self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) + self.init_weights() def get_output_embeddings(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6c2eebb9ac..4fd6217223 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -164,7 +164,7 @@ class ModelTesterMixin: max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) - def test_save_load__keys_to_ignore_on_save(self): + def test_save_load_keys_to_ignore_on_save(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -175,7 +175,7 @@ class ModelTesterMixin: # check the keys are in the original state_dict for k in _keys_to_ignore_on_save: - self.assertIn(k, model.state_dict()) + self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys())) # check that certain keys didn't get saved with the model with tempfile.TemporaryDirectory() as tmpdirname: @@ -183,7 +183,7 @@ class ModelTesterMixin: output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME) state_dict_saved = torch.load(output_model_file) for k in _keys_to_ignore_on_save: - self.assertNotIn(k, state_dict_saved) + self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys())) # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer. load_result = model.load_state_dict(state_dict_saved, strict=False) diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 168e5073d7..bed69c3469 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -15,9 +15,10 @@ import unittest +from copy import deepcopy from transformers import is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device from .test_configuration_common import ConfigTester from .test_generation_utils import GenerationTesterMixin @@ -43,6 +44,8 @@ if is_torch_available(): create_position_ids_from_input_ids, ) +ROBERTA_TINY = "sshleifer/tiny-distilroberta-base" + class RobertaModelTester: def __init__( @@ -475,7 +478,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas @require_torch -class RobertaModelIntegrationTest(unittest.TestCase): +class RobertaModelIntegrationTest(TestCasePlus): @slow def test_inference_masked_lm(self): model = RobertaForMaskedLM.from_pretrained("roberta-base") @@ -527,3 +530,23 @@ class RobertaModelIntegrationTest(unittest.TestCase): # expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach() self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4)) + + # XXX: this might be a candidate for common tests if we have many of those + def test_lm_head_ignore_keys(self): + keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"] + config = RobertaConfig.from_pretrained(ROBERTA_TINY) + config_tied = deepcopy(config) + config_tied.tie_word_embeddings = True + config_untied = deepcopy(config) + config_untied.tie_word_embeddings = False + for cls in [RobertaForMaskedLM, RobertaForCausalLM]: + model = cls(config_tied) + self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls) + + # the keys should be different when embeddings aren't tied + model = cls(config_untied) + self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls) + + # test that saving works with updated ignore keys - just testing that it doesn't fail + model.save_pretrained(self.get_auto_remove_tmp_dir())