[WIP] Hard error when ignoring tensors. (#27484)
* [WIP] Hard error when ignoring tensors. * Better selection/error when saving a checkpoint. - Find all names we should normally drop (those are in the transformers config) - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving) - Clone those disjoint tensors getting rid of the issue - Find all identical names (those should be declared in the config but we try to find them all anyway.) - For all identical names: - If they are in the config, just ignore them everything is fine - If they are not, warn about them. - For all remainder tensors which are shared yet neither identical NOR disjoint. raise a hard error. * Adding a failing test on `main` that passes here. * We don't need to keep the subfolder logic in this test. * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -29,7 +29,7 @@ import warnings
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
from zipfile import is_zipfile
|
from zipfile import is_zipfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -570,6 +570,65 @@ def set_initialized_submodules(model, state_dict_keys):
|
|||||||
return not_initialized_submodules
|
return not_initialized_submodules
|
||||||
|
|
||||||
|
|
||||||
|
def _end_ptr(tensor: torch.Tensor) -> int:
|
||||||
|
# extract the end of the pointer if the tensor is a slice of a bigger tensor
|
||||||
|
if tensor.nelement():
|
||||||
|
stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
|
||||||
|
else:
|
||||||
|
stop = tensor.data_ptr()
|
||||||
|
return stop
|
||||||
|
|
||||||
|
|
||||||
|
def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
|
||||||
|
filtered_tensors = []
|
||||||
|
for shared in tensors:
|
||||||
|
if len(shared) < 2:
|
||||||
|
filtered_tensors.append(shared)
|
||||||
|
continue
|
||||||
|
|
||||||
|
areas = []
|
||||||
|
for name in shared:
|
||||||
|
tensor = state_dict[name]
|
||||||
|
areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
|
||||||
|
areas.sort()
|
||||||
|
|
||||||
|
_, last_stop, last_name = areas[0]
|
||||||
|
filtered_tensors.append({last_name})
|
||||||
|
for start, stop, name in areas[1:]:
|
||||||
|
if start >= last_stop:
|
||||||
|
filtered_tensors.append({name})
|
||||||
|
else:
|
||||||
|
filtered_tensors[-1].add(name)
|
||||||
|
last_stop = stop
|
||||||
|
disjoint_tensors = []
|
||||||
|
shared_tensors = []
|
||||||
|
for tensors in filtered_tensors:
|
||||||
|
if len(tensors) == 1:
|
||||||
|
disjoint_tensors.append(tensors.pop())
|
||||||
|
else:
|
||||||
|
shared_tensors.append(tensors)
|
||||||
|
return shared_tensors, disjoint_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
|
||||||
|
shared_tensors = []
|
||||||
|
identical = []
|
||||||
|
for shared in tensors:
|
||||||
|
if len(shared) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
areas = collections.defaultdict(set)
|
||||||
|
for name in shared:
|
||||||
|
tensor = state_dict[name]
|
||||||
|
area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
|
||||||
|
areas[area].add(name)
|
||||||
|
if len(areas) == 1:
|
||||||
|
identical.append(shared)
|
||||||
|
else:
|
||||||
|
shared_tensors.append(shared)
|
||||||
|
return shared_tensors, identical
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
@@ -2382,6 +2441,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# These are all the pointers of shared tensors.
|
# These are all the pointers of shared tensors.
|
||||||
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
||||||
warn_names = set()
|
warn_names = set()
|
||||||
|
error_names = set()
|
||||||
|
to_delete_names = set()
|
||||||
for names in shared_ptrs.values():
|
for names in shared_ptrs.values():
|
||||||
# Removing the keys which are declared as known duplicates on
|
# Removing the keys which are declared as known duplicates on
|
||||||
# load. This allows to make sure the name which is kept is consistent.
|
# load. This allows to make sure the name which is kept is consistent.
|
||||||
@@ -2392,25 +2453,42 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if matches_pattern and name in state_dict:
|
if matches_pattern and name in state_dict:
|
||||||
found += 1
|
found += 1
|
||||||
if found < len(names):
|
if found < len(names):
|
||||||
del state_dict[name]
|
to_delete_names.add(name)
|
||||||
|
# We are entering a place where the weights and the transformers configuration do NOT match.
|
||||||
|
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
|
||||||
|
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
|
||||||
|
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
|
||||||
|
for name in disjoint_names:
|
||||||
|
state_dict[name] = state_dict[name].clone()
|
||||||
|
|
||||||
|
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
|
||||||
|
# If the link between tensors was done at runtime then `from_pretrained` will not get
|
||||||
|
# the key back leading to random tensor. A proper warning will be shown
|
||||||
|
# during reload (if applicable), but since the file is not necessarily compatible with
|
||||||
|
# the config, better show a proper warning.
|
||||||
|
shared_names, identical_names = _find_identical(shared_names, state_dict)
|
||||||
|
# delete tensors that have identical storage
|
||||||
|
for inames in identical_names:
|
||||||
|
known = inames.intersection(to_delete_names)
|
||||||
|
for name in known:
|
||||||
|
del state_dict[name]
|
||||||
|
unknown = sorted(inames.difference(to_delete_names))
|
||||||
|
for name in unknown[1:]:
|
||||||
|
del state_dict[name]
|
||||||
|
warn_names.add(name)
|
||||||
|
|
||||||
|
error_names.update(shared_names)
|
||||||
|
|
||||||
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
|
|
||||||
# If the link between tensors was done at runtime then `from_pretrained` will not get
|
|
||||||
# the key back leading to random tensor. A proper warning will be shown
|
|
||||||
# during reload (if applicable), but since the file is not necessarily compatible with
|
|
||||||
# the config, better show a proper warning.
|
|
||||||
found = 0
|
|
||||||
for name in names:
|
|
||||||
if name in state_dict:
|
|
||||||
found += 1
|
|
||||||
if found > 1:
|
|
||||||
del state_dict[name]
|
|
||||||
warn_names.add(name)
|
|
||||||
if len(warn_names) > 0:
|
if len(warn_names) > 0:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
|
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(error_names) > 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
|
||||||
|
)
|
||||||
|
|
||||||
# Shard the model if it is too big.
|
# Shard the model if it is too big.
|
||||||
if not _hf_peft_config_loaded:
|
if not _hf_peft_config_loaded:
|
||||||
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||||
|
|||||||
@@ -257,6 +257,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
self.assertTrue(check_models_equal(model, model_loaded))
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_manually_shared_disjointed_tensors_optimum(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
model = BertModel(config)
|
||||||
|
|
||||||
|
# Let's fuse qkv
|
||||||
|
attn = model.encoder.layer[0].attention.self
|
||||||
|
q = attn.query.weight
|
||||||
|
k = attn.key.weight
|
||||||
|
v = attn.value.weight
|
||||||
|
# Force some shared storage
|
||||||
|
qkv = torch.stack([q, k, v], dim=0)
|
||||||
|
attn.query.weight = torch.nn.Parameter(qkv[0])
|
||||||
|
attn.key.weight = torch.nn.Parameter(qkv[1])
|
||||||
|
attn.value.weight = torch.nn.Parameter(qkv[2])
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
model_loaded = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
def test_model_from_pretrained_subfolder_sharded(self):
|
def test_model_from_pretrained_subfolder_sharded(self):
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
model = BertModel(config)
|
model = BertModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user