Making sure we can use safetensors to serialize all the time. (#22437)

* Making sure we can use safetensors to serialize all the time.

* Expanding the tests for increased coverage.

* Update the test.

* Getting current state of affairs.

* Tentative fix.

* Fixing black version.

* Fixing the worst offenders.

* Try to modify less files.

* Fixing blip_2 (Weird solution right now).

* Fixing deta.

* Fix blip ?

* Missing extra newline.

* No deta modification.

* Adding some comments.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Addressing comments.

* Addressing comments.

* creating warn_once.

* Warning_once !

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-03-31 16:07:35 +02:00
committed by GitHub
parent 516077b3b0
commit d143087d18
6 changed files with 101 additions and 6 deletions

View File

@@ -1736,6 +1736,41 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for ignore_key in self._keys_to_ignore_on_save: for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in state_dict.keys(): if ignore_key in state_dict.keys():
del state_dict[ignore_key] del state_dict[ignore_key]
if safe_serialization:
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
ptrs[tensor.data_ptr()].append(name)
# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
for names in shared_ptrs.values():
# Removing the keys which are declared as known duplicates on
# load. This allows to make sure the name which is kept is consistent.
if self._keys_to_ignore_on_load_missing is not None:
for name in names:
matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
if matches_pattern and name in state_dict:
del state_dict[name]
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
found = 0
for name in names:
if name in state_dict:
found += 1
if found > 1:
del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
logger.warning_once(
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
)
# Shard the model if it is too big. # Shard the model if it is too big.
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
@@ -2813,6 +2848,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
missing_keys = list(set(expected_keys) - set(loaded_keys)) missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys))
# Some tensors maybe have been already filled by another key (tied weights).
existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
missing_keys = [
k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
]
# Some models may have keys that are not in the state by design, removing them before needlessly warning # Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user. # the user.
if cls._keys_to_ignore_on_load_missing is not None: if cls._keys_to_ignore_on_load_missing is not None:

View File

@@ -1238,8 +1238,28 @@ class Blip2Model(Blip2PreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self):
return self.vision_model.embeddings.patch_embedding return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Module:
return self.language_model.get_output_embeddings()
def get_encoder(self):
return self.language_model.get_encoder()
def get_decoder(self):
return self.language_model.get_decoder()
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
@add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
def get_text_features( def get_text_features(

View File

@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):
def _get_clones(module, N): def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([module for i in range(N)])
def inverse_sigmoid(x, eps=1e-5): def inverse_sigmoid(x, eps=1e-5):

View File

@@ -609,8 +609,6 @@ class LlamaModel(LlamaPreTrainedModel):
class LlamaForCausalLM(LlamaPreTrainedModel): class LlamaForCausalLM(LlamaPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.model = LlamaModel(config) self.model = LlamaModel(config)

View File

@@ -357,9 +357,10 @@ class Pix2StructConfig(PretrainedConfig):
initializer_factor=1.0, initializer_factor=1.0,
initializer_range=0.02, initializer_range=0.02,
is_vqa=False, is_vqa=False,
tie_word_embeddings=False,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
if text_config is None: if text_config is None:
text_config = {} text_config = {}

View File

@@ -27,6 +27,7 @@ import tempfile
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
import warnings import warnings
from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@@ -1626,6 +1627,41 @@ class ModelTesterMixin:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape) # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
@require_safetensors
def test_can_use_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
try:
model_tied.save_pretrained(d, safe_serialization=True)
except Exception as e:
raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# Checking the state dicts are correct
reloaded_state = model_reloaded.state_dict()
for k, v in model_tied.state_dict().items():
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking the tensor sharing are correct
ptrs = defaultdict(list)
for k, v in model_tied.state_dict().items():
ptrs[v.data_ptr()].append(k)
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
for _, shared_names in shared_ptrs.items():
reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
self.assertEqual(
len(reloaded_ptrs),
1,
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)
def test_tied_model_weights_key_ignore(self): def test_tied_model_weights_key_ignore(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes: