Fix safetensors failing tests (#27231)
* Fix Kosmos2 * Fix ProphetNet * Fix MarianMT * Fix M4T * XLM ProphetNet * ProphetNet fix * XLM ProphetNet * Final M4T fixes * Tied weights keys * Revert M4T changes * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -304,6 +304,25 @@ class Kosmos2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_load_save_without_tied_weights(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.text_config.tie_word_embeddings = False
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
model.save_pretrained(d)
|
||||
|
||||
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
|
||||
# Checking the state dicts are correct
|
||||
reloaded_state = model_reloaded.state_dict()
|
||||
for k, v in model.state_dict().items():
|
||||
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
|
||||
torch.testing.assert_close(
|
||||
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
|
||||
)
|
||||
# Checking there was no complain of missing weights
|
||||
self.assertEqual(infos["missing_keys"], [])
|
||||
|
||||
# overwrite from common in order to use `self.model_tester.text_model_tester.num_hidden_layers`
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
|
||||
@@ -76,7 +76,7 @@ from transformers.testing_utils import (
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
GENERATION_CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
@@ -91,6 +91,7 @@ if is_accelerate_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch import nn
|
||||
|
||||
@@ -311,17 +312,20 @@ class ModelTesterMixin:
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||
state_dict_saved = torch.load(output_model_file)
|
||||
output_model_file = os.path.join(tmpdirname, SAFE_WEIGHTS_NAME)
|
||||
state_dict_saved = safe_load_file(output_model_file)
|
||||
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))
|
||||
|
||||
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
||||
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
||||
self.assertTrue(
|
||||
len(load_result.missing_keys) == 0
|
||||
or set(load_result.missing_keys) == set(model._keys_to_ignore_on_save)
|
||||
)
|
||||
keys_to_ignore = set(model._keys_to_ignore_on_save)
|
||||
|
||||
if hasattr(model, "_tied_weights_keys"):
|
||||
keys_to_ignore.update(set(model._tied_weights_keys))
|
||||
|
||||
self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore)
|
||||
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
||||
|
||||
def test_gradient_checkpointing_backward_compatibility(self):
|
||||
|
||||
Reference in New Issue
Block a user