Fix safetensors failing tests (#27231)

* Fix Kosmos2

* Fix ProphetNet

* Fix MarianMT

* Fix M4T

* XLM ProphetNet

* ProphetNet fix

* XLM ProphetNet

* Final M4T fixes

* Tied weights keys

* Revert M4T changes

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-11-02 15:03:09 +01:00
committed by GitHub
parent 4557a0dede
commit 443bf5e9e2
4 changed files with 84 additions and 11 deletions

View File

@@ -76,7 +76,7 @@ from transformers.testing_utils import (
from transformers.utils import (
CONFIG_NAME,
GENERATION_CONFIG_NAME,
WEIGHTS_NAME,
SAFE_WEIGHTS_NAME,
is_accelerate_available,
is_flax_available,
is_tf_available,
@@ -91,6 +91,7 @@ if is_accelerate_available():
if is_torch_available():
import torch
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch import nn
@@ -311,17 +312,20 @@ class ModelTesterMixin:
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
state_dict_saved = torch.load(output_model_file)
output_model_file = os.path.join(tmpdirname, SAFE_WEIGHTS_NAME)
state_dict_saved = safe_load_file(output_model_file)
for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
load_result = model.load_state_dict(state_dict_saved, strict=False)
self.assertTrue(
len(load_result.missing_keys) == 0
or set(load_result.missing_keys) == set(model._keys_to_ignore_on_save)
)
keys_to_ignore = set(model._keys_to_ignore_on_save)
if hasattr(model, "_tied_weights_keys"):
keys_to_ignore.update(set(model._tied_weights_keys))
self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore)
self.assertTrue(len(load_result.unexpected_keys) == 0)
def test_gradient_checkpointing_backward_compatibility(self):