[roberta] fix lm_head.decoder.weight ignore_key handling (#12446)
* fix lm_head.decoder.weight ignore_key handling
* fix the mutable class variable
* Update src/transformers/models/roberta/modeling_roberta.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* replicate the comment
* make deterministic

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -164,7 +164,7 @@ class ModelTesterMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_save_load__keys_to_ignore_on_save(self):
|
||||
def test_save_load_keys_to_ignore_on_save(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -175,7 +175,7 @@ class ModelTesterMixin:
|
||||
|
||||
# check the keys are in the original state_dict
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertIn(k, model.state_dict())
|
||||
self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -183,7 +183,7 @@ class ModelTesterMixin:
|
||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||
state_dict_saved = torch.load(output_model_file)
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))
|
||||
|
||||
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
||||
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
||||
|
||||
Reference in New Issue
Block a user