Making sure we can use safetensors to serialize all the time. (#22437)
* Making sure we can use safetensors to serialize all the time. * Expanding the tests for increased coverage. * Update the test. * Getting current state of affairs. * Tentative fix. * Fixing black version. * Fixing the worst offenders. * Try to modify less files. * Fixing blip_2 (Weird solution right now). * Fixing deta. * Fix blip ? * Missing extra newline. * No deta modification. * Adding some comments. * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Addressing comments. * Addressing comments. * creating warn_once. * Warning_once ! --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
@@ -1626,6 +1627,41 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
@require_safetensors
|
||||
def test_can_use_safetensors(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
try:
|
||||
model_tied.save_pretrained(d, safe_serialization=True)
|
||||
except Exception as e:
|
||||
raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
|
||||
|
||||
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
|
||||
# Checking the state dicts are correct
|
||||
reloaded_state = model_reloaded.state_dict()
|
||||
for k, v in model_tied.state_dict().items():
|
||||
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
|
||||
torch.testing.assert_close(
|
||||
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
|
||||
)
|
||||
|
||||
# Checking the tensor sharing are correct
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in model_tied.state_dict().items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
|
||||
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
|
||||
|
||||
for _, shared_names in shared_ptrs.items():
|
||||
reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
|
||||
self.assertEqual(
|
||||
len(reloaded_ptrs),
|
||||
1,
|
||||
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
|
||||
)
|
||||
|
||||
def test_tied_model_weights_key_ignore(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user