[roberta] fix lm_head.decoder.weight ignore_key handling (#12446)
* fix lm_head.decoder.weight ignore_key handling * fix the mutable class variable * Update src/transformers/models/roberta/modeling_roberta.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * replicate the comment * make deterministic Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -445,7 +445,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# (and avoid unnecessary warnings).
|
# (and avoid unnecessary warnings).
|
||||||
_keys_to_ignore_on_load_unexpected = None
|
_keys_to_ignore_on_load_unexpected = None
|
||||||
# a list of of tensor names to ignore when saving the model (useful for keys that aren't
|
# a list of of tensor names to ignore when saving the model (useful for keys that aren't
|
||||||
# trained, but which are deterministic)
|
# trained, but which are deterministic, or tied variables)
|
||||||
_keys_to_ignore_on_save = None
|
_keys_to_ignore_on_save = None
|
||||||
|
|
||||||
is_parallelizable = False
|
is_parallelizable = False
|
||||||
|
|||||||
@@ -603,6 +603,15 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
module.weight.data.fill_(1.0)
|
module.weight.data.fill_(1.0)
|
||||||
|
|
||||||
|
def update_keys_to_ignore(self, config, del_keys_to_ignore):
|
||||||
|
"""Remove some keys from ignore list"""
|
||||||
|
if not config.tie_word_embeddings:
|
||||||
|
# must make a new list, or the class variable gets modified!
|
||||||
|
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
|
||||||
|
self._keys_to_ignore_on_load_missing = [
|
||||||
|
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
ROBERTA_START_DOCSTRING = r"""
|
ROBERTA_START_DOCSTRING = r"""
|
||||||
|
|
||||||
@@ -864,7 +873,8 @@ class RobertaModel(RobertaPreTrainedModel):
|
|||||||
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
|
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
|
||||||
)
|
)
|
||||||
class RobertaForCausalLM(RobertaPreTrainedModel):
|
class RobertaForCausalLM(RobertaPreTrainedModel):
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
|
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@@ -876,6 +886,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
||||||
self.lm_head = RobertaLMHead(config)
|
self.lm_head = RobertaLMHead(config)
|
||||||
|
|
||||||
|
# The LM head weights require special treatment only when they are tied with the word embeddings
|
||||||
|
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
@@ -1010,7 +1023,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||||
class RobertaForMaskedLM(RobertaPreTrainedModel):
|
class RobertaForMaskedLM(RobertaPreTrainedModel):
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
|
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@@ -1025,6 +1039,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
|
|||||||
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
self.roberta = RobertaModel(config, add_pooling_layer=False)
|
||||||
self.lm_head = RobertaLMHead(config)
|
self.lm_head = RobertaLMHead(config)
|
||||||
|
|
||||||
|
# The LM head weights require special treatment only when they are tied with the word embeddings
|
||||||
|
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class ModelTesterMixin:
|
|||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def test_save_load__keys_to_ignore_on_save(self):
|
def test_save_load_keys_to_ignore_on_save(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -175,7 +175,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# check the keys are in the original state_dict
|
# check the keys are in the original state_dict
|
||||||
for k in _keys_to_ignore_on_save:
|
for k in _keys_to_ignore_on_save:
|
||||||
self.assertIn(k, model.state_dict())
|
self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))
|
||||||
|
|
||||||
# check that certain keys didn't get saved with the model
|
# check that certain keys didn't get saved with the model
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
@@ -183,7 +183,7 @@ class ModelTesterMixin:
|
|||||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||||
state_dict_saved = torch.load(output_model_file)
|
state_dict_saved = torch.load(output_model_file)
|
||||||
for k in _keys_to_ignore_on_save:
|
for k in _keys_to_ignore_on_save:
|
||||||
self.assertNotIn(k, state_dict_saved)
|
self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))
|
||||||
|
|
||||||
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
||||||
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
||||||
|
|||||||
@@ -15,9 +15,10 @@
|
|||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_generation_utils import GenerationTesterMixin
|
from .test_generation_utils import GenerationTesterMixin
|
||||||
@@ -43,6 +44,8 @@ if is_torch_available():
|
|||||||
create_position_ids_from_input_ids,
|
create_position_ids_from_input_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
|
||||||
|
|
||||||
|
|
||||||
class RobertaModelTester:
|
class RobertaModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -475,7 +478,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class RobertaModelIntegrationTest(unittest.TestCase):
|
class RobertaModelIntegrationTest(TestCasePlus):
|
||||||
@slow
|
@slow
|
||||||
def test_inference_masked_lm(self):
|
def test_inference_masked_lm(self):
|
||||||
model = RobertaForMaskedLM.from_pretrained("roberta-base")
|
model = RobertaForMaskedLM.from_pretrained("roberta-base")
|
||||||
@@ -527,3 +530,23 @@ class RobertaModelIntegrationTest(unittest.TestCase):
|
|||||||
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
|
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
|
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
|
||||||
|
|
||||||
|
# XXX: this might be a candidate for common tests if we have many of those
|
||||||
|
def test_lm_head_ignore_keys(self):
|
||||||
|
keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
|
||||||
|
keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
|
||||||
|
config = RobertaConfig.from_pretrained(ROBERTA_TINY)
|
||||||
|
config_tied = deepcopy(config)
|
||||||
|
config_tied.tie_word_embeddings = True
|
||||||
|
config_untied = deepcopy(config)
|
||||||
|
config_untied.tie_word_embeddings = False
|
||||||
|
for cls in [RobertaForMaskedLM, RobertaForCausalLM]:
|
||||||
|
model = cls(config_tied)
|
||||||
|
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
|
||||||
|
|
||||||
|
# the keys should be different when embeddings aren't tied
|
||||||
|
model = cls(config_untied)
|
||||||
|
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
|
||||||
|
|
||||||
|
# test that saving works with updated ignore keys - just testing that it doesn't fail
|
||||||
|
model.save_pretrained(self.get_auto_remove_tmp_dir())
|
||||||
|
|||||||
Reference in New Issue
Block a user