Making sure we can use safetensors to serialize all the time. (#22437)

* Making sure we can use safetensors to serialize all the time.

* Expanding the tests for increased coverage.

* Update the test.

* Getting current state of affairs.

* Tentative fix.

* Fixing black version.

* Fixing the worst offenders.

* Try to modify less files.

* Fixing blip_2 (Weird solution right now).

* Fixing deta.

* Fix blip ?

* Missing extra newline.

* No deta modification.

* Adding some comments.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Addressing comments.

* Addressing comments.

* creating warn_once.

* Warning_once !

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-03-31 16:07:35 +02:00
committed by GitHub
parent 516077b3b0
commit d143087d18
6 changed files with 101 additions and 6 deletions

View File

@@ -27,6 +27,7 @@ import tempfile
import unittest
import unittest.mock as mock
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -1626,6 +1627,41 @@ class ModelTesterMixin:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
@require_safetensors
def test_can_use_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
try:
model_tied.save_pretrained(d, safe_serialization=True)
except Exception as e:
raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# Checking the state dicts are correct
reloaded_state = model_reloaded.state_dict()
for k, v in model_tied.state_dict().items():
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking the tensor sharing are correct
ptrs = defaultdict(list)
for k, v in model_tied.state_dict().items():
ptrs[v.data_ptr()].append(k)
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
for _, shared_names in shared_ptrs.items():
reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
self.assertEqual(
len(reloaded_ptrs),
1,
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)
def test_tied_model_weights_key_ignore(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: