* Hard error when ignoring tensors. (#27484) * [WIP] Hard error when ignoring tensors. * Better selection/error when saving a checkpoint. - Find all names we should normally drop (those are in the transformers config) - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving) - Clone those disjoint tensors getting rid of the issue - Find all identical names (those should be declared in the config but we try to find them all anyway.) - For all identical names: - If they are in the config, just ignore them everything is fine - If they are not, warn about them. - For all remainder tensors which are shared yet neither identical NOR disjoint. raise a hard error. * Adding a failing test on `main` that passes here. * We don't need to keep the subfolder logic in this test. * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add small tests. * Dead variable. * Fixup. * Fixing tied_Weights_keys on generic models. * Fixup + T5 encoder/decoder tying (with different layers) * Code quality. * Dynamic member. * trigger * Fixing encoder name for other types of encoder/decoder combos. * Fix scoping. * Update .github/workflows/self-scheduled.yml Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fixing the tied_weights after the call. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,7 @@ from contextlib import contextmanager
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
from zipfile import is_zipfile
|
from zipfile import is_zipfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -573,6 +573,79 @@ def set_initialized_submodules(model, state_dict_keys):
|
|||||||
return not_initialized_submodules
|
return not_initialized_submodules
|
||||||
|
|
||||||
|
|
||||||
|
def _end_ptr(tensor: torch.Tensor) -> int:
|
||||||
|
# extract the end of the pointer if the tensor is a slice of a bigger tensor
|
||||||
|
if tensor.nelement():
|
||||||
|
stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
|
||||||
|
else:
|
||||||
|
stop = tensor.data_ptr()
|
||||||
|
return stop
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tied_weight_keys(module: nn.Module, prefix=""):
|
||||||
|
tied_weight_keys = []
|
||||||
|
if getattr(module, "_tied_weights_keys", None) is not None:
|
||||||
|
names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
|
||||||
|
tied_weight_keys.extend(names)
|
||||||
|
if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
|
||||||
|
names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
|
||||||
|
tied_weight_keys.extend(names)
|
||||||
|
for name, submodule in module.named_children():
|
||||||
|
local_prefix = f"{prefix}.{name}" if prefix else name
|
||||||
|
tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
|
||||||
|
return tied_weight_keys
|
||||||
|
|
||||||
|
|
||||||
|
def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
|
||||||
|
filtered_tensors = []
|
||||||
|
for shared in tensors:
|
||||||
|
if len(shared) < 2:
|
||||||
|
filtered_tensors.append(shared)
|
||||||
|
continue
|
||||||
|
|
||||||
|
areas = []
|
||||||
|
for name in shared:
|
||||||
|
tensor = state_dict[name]
|
||||||
|
areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
|
||||||
|
areas.sort()
|
||||||
|
|
||||||
|
_, last_stop, last_name = areas[0]
|
||||||
|
filtered_tensors.append({last_name})
|
||||||
|
for start, stop, name in areas[1:]:
|
||||||
|
if start >= last_stop:
|
||||||
|
filtered_tensors.append({name})
|
||||||
|
else:
|
||||||
|
filtered_tensors[-1].add(name)
|
||||||
|
last_stop = stop
|
||||||
|
disjoint_tensors = []
|
||||||
|
shared_tensors = []
|
||||||
|
for tensors in filtered_tensors:
|
||||||
|
if len(tensors) == 1:
|
||||||
|
disjoint_tensors.append(tensors.pop())
|
||||||
|
else:
|
||||||
|
shared_tensors.append(tensors)
|
||||||
|
return shared_tensors, disjoint_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
|
||||||
|
shared_tensors = []
|
||||||
|
identical = []
|
||||||
|
for shared in tensors:
|
||||||
|
if len(shared) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
areas = collections.defaultdict(set)
|
||||||
|
for name in shared:
|
||||||
|
tensor = state_dict[name]
|
||||||
|
area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
|
||||||
|
areas[area].add(name)
|
||||||
|
if len(areas) == 1:
|
||||||
|
identical.append(shared)
|
||||||
|
else:
|
||||||
|
shared_tensors.append(shared)
|
||||||
|
return shared_tensors, identical
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
@@ -1646,15 +1719,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
|
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
|
||||||
if hasattr(self, self.base_model_prefix):
|
if hasattr(self, self.base_model_prefix):
|
||||||
self = getattr(self, self.base_model_prefix)
|
self = getattr(self, self.base_model_prefix)
|
||||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
tied_weights = self._tie_encoder_decoder_weights(
|
||||||
|
self.encoder, self.decoder, self.base_model_prefix, "encoder"
|
||||||
|
)
|
||||||
|
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
|
||||||
|
# attributed not an instance member, therefore modifying it will modify the entire class
|
||||||
|
# Leading to issues on subsequent calls by different tests or subsequent calls.
|
||||||
|
self._dynamic_tied_weights_keys = tied_weights
|
||||||
|
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if hasattr(module, "_tie_weights"):
|
if hasattr(module, "_tie_weights"):
|
||||||
module._tie_weights()
|
module._tie_weights()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
def _tie_encoder_decoder_weights(
|
||||||
|
encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
|
||||||
|
):
|
||||||
uninitialized_encoder_weights: List[str] = []
|
uninitialized_encoder_weights: List[str] = []
|
||||||
|
tied_weights: List[str] = []
|
||||||
if decoder.__class__ != encoder.__class__:
|
if decoder.__class__ != encoder.__class__:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
|
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
|
||||||
@@ -1665,8 +1747,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
decoder_pointer: nn.Module,
|
decoder_pointer: nn.Module,
|
||||||
encoder_pointer: nn.Module,
|
encoder_pointer: nn.Module,
|
||||||
module_name: str,
|
module_name: str,
|
||||||
|
base_encoder_name: str,
|
||||||
uninitialized_encoder_weights: List[str],
|
uninitialized_encoder_weights: List[str],
|
||||||
depth=0,
|
depth=0,
|
||||||
|
total_decoder_name="",
|
||||||
|
total_encoder_name="",
|
||||||
):
|
):
|
||||||
assert isinstance(decoder_pointer, nn.Module) and isinstance(
|
assert isinstance(decoder_pointer, nn.Module) and isinstance(
|
||||||
encoder_pointer, nn.Module
|
encoder_pointer, nn.Module
|
||||||
@@ -1674,8 +1759,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if hasattr(decoder_pointer, "weight"):
|
if hasattr(decoder_pointer, "weight"):
|
||||||
assert hasattr(encoder_pointer, "weight")
|
assert hasattr(encoder_pointer, "weight")
|
||||||
encoder_pointer.weight = decoder_pointer.weight
|
encoder_pointer.weight = decoder_pointer.weight
|
||||||
|
tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
|
||||||
if hasattr(decoder_pointer, "bias"):
|
if hasattr(decoder_pointer, "bias"):
|
||||||
assert hasattr(encoder_pointer, "bias")
|
assert hasattr(encoder_pointer, "bias")
|
||||||
|
tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
|
||||||
encoder_pointer.bias = decoder_pointer.bias
|
encoder_pointer.bias = decoder_pointer.bias
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1713,19 +1800,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
decoder_modules[decoder_name],
|
decoder_modules[decoder_name],
|
||||||
encoder_modules[encoder_name],
|
encoder_modules[encoder_name],
|
||||||
module_name + "/" + name,
|
module_name + "/" + name,
|
||||||
|
base_encoder_name,
|
||||||
uninitialized_encoder_weights,
|
uninitialized_encoder_weights,
|
||||||
depth=depth + 1,
|
depth=depth + 1,
|
||||||
|
total_encoder_name=f"{total_encoder_name}.{encoder_name}",
|
||||||
|
total_decoder_name=f"{total_decoder_name}.{decoder_name}",
|
||||||
)
|
)
|
||||||
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
all_encoder_weights.remove(module_name + "/" + encoder_name)
|
||||||
|
|
||||||
uninitialized_encoder_weights += list(all_encoder_weights)
|
uninitialized_encoder_weights += list(all_encoder_weights)
|
||||||
|
|
||||||
# tie weights recursively
|
# tie weights recursively
|
||||||
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
|
tie_encoder_to_decoder_recursively(
|
||||||
|
decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
|
||||||
|
)
|
||||||
|
|
||||||
if len(uninitialized_encoder_weights) > 0:
|
if len(uninitialized_encoder_weights) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
|
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
|
||||||
)
|
)
|
||||||
|
return tied_weights
|
||||||
|
|
||||||
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
||||||
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
|
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
|
||||||
@@ -2402,34 +2496,49 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# These are all the pointers of shared tensors.
|
# These are all the pointers of shared tensors.
|
||||||
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
||||||
warn_names = set()
|
error_names = []
|
||||||
|
to_delete_names = set()
|
||||||
|
# Recursively descend to find tied weight keys
|
||||||
|
_tied_weights_keys = _get_tied_weight_keys(self)
|
||||||
for names in shared_ptrs.values():
|
for names in shared_ptrs.values():
|
||||||
# Removing the keys which are declared as known duplicates on
|
# Removing the keys which are declared as known duplicates on
|
||||||
# load. This allows to make sure the name which is kept is consistent.
|
# load. This allows to make sure the name which is kept is consistent.
|
||||||
if self._tied_weights_keys is not None:
|
if _tied_weights_keys is not None:
|
||||||
found = 0
|
found = 0
|
||||||
for name in sorted(names):
|
for name in sorted(names):
|
||||||
matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
|
matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
|
||||||
if matches_pattern and name in state_dict:
|
if matches_pattern and name in state_dict:
|
||||||
found += 1
|
found += 1
|
||||||
if found < len(names):
|
if found < len(names):
|
||||||
del state_dict[name]
|
to_delete_names.add(name)
|
||||||
|
# We are entering a place where the weights and the transformers configuration do NOT match.
|
||||||
|
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
|
||||||
|
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
|
||||||
|
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
|
||||||
|
for name in disjoint_names:
|
||||||
|
state_dict[name] = state_dict[name].clone()
|
||||||
|
|
||||||
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
|
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
|
||||||
# If the link between tensors was done at runtime then `from_pretrained` will not get
|
# If the link between tensors was done at runtime then `from_pretrained` will not get
|
||||||
# the key back leading to random tensor. A proper warning will be shown
|
# the key back leading to random tensor. A proper warning will be shown
|
||||||
# during reload (if applicable), but since the file is not necessarily compatible with
|
# during reload (if applicable), but since the file is not necessarily compatible with
|
||||||
# the config, better show a proper warning.
|
# the config, better show a proper warning.
|
||||||
found = 0
|
shared_names, identical_names = _find_identical(shared_names, state_dict)
|
||||||
for name in names:
|
# delete tensors that have identical storage
|
||||||
if name in state_dict:
|
for inames in identical_names:
|
||||||
found += 1
|
known = inames.intersection(to_delete_names)
|
||||||
if found > 1:
|
for name in known:
|
||||||
del state_dict[name]
|
del state_dict[name]
|
||||||
warn_names.add(name)
|
unknown = inames.difference(to_delete_names)
|
||||||
if len(warn_names) > 0:
|
if len(unknown) > 1:
|
||||||
logger.warning_once(
|
error_names.append(unknown)
|
||||||
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
|
|
||||||
|
if shared_names:
|
||||||
|
error_names.append(set(shared_names))
|
||||||
|
|
||||||
|
if len(error_names) > 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Shard the model if it is too big.
|
# Shard the model if it is too big.
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch BERT model."""
|
"""PyTorch BERT model."""
|
||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -1128,7 +1127,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
|
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
|
||||||
)
|
)
|
||||||
class BertLMHeadModel(BertPreTrainedModel):
|
class BertLMHeadModel(BertPreTrainedModel):
|
||||||
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -262,9 +262,16 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
if self.config.tie_encoder_decoder:
|
if self.config.tie_encoder_decoder:
|
||||||
# tie encoder and decoder base model
|
# tie encoder and decoder base model
|
||||||
decoder_base_model_prefix = self.decoder.base_model_prefix
|
decoder_base_model_prefix = self.decoder.base_model_prefix
|
||||||
self._tie_encoder_decoder_weights(
|
tied_weights = self._tie_encoder_decoder_weights(
|
||||||
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
|
self.encoder,
|
||||||
|
self.decoder._modules[decoder_base_model_prefix],
|
||||||
|
self.decoder.base_model_prefix,
|
||||||
|
"encoder",
|
||||||
)
|
)
|
||||||
|
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
|
||||||
|
# attributed not an instance member, therefore modifying it will modify the entire class
|
||||||
|
# Leading to issues on subsequent calls by different tests or subsequent calls.
|
||||||
|
self._dynamic_tied_weights_keys = tied_weights
|
||||||
|
|
||||||
def get_encoder(self):
|
def get_encoder(self):
|
||||||
return self.encoder
|
return self.encoder
|
||||||
|
|||||||
@@ -1343,7 +1343,13 @@ class MarianMTModel(MarianPreTrainedModel):
|
|||||||
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
|
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
|
||||||
if hasattr(self, self.base_model_prefix):
|
if hasattr(self, self.base_model_prefix):
|
||||||
self = getattr(self, self.base_model_prefix)
|
self = getattr(self, self.base_model_prefix)
|
||||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
tied_weights = self._tie_encoder_decoder_weights(
|
||||||
|
self.encoder, self.decoder, self.base_model_prefix, "encoder"
|
||||||
|
)
|
||||||
|
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
|
||||||
|
# attributed not an instance member, therefore modifying it will modify the entire class
|
||||||
|
# Leading to issues on subsequent calls by different tests or subsequent calls.
|
||||||
|
self._dynamic_tied_weights_keys = tied_weights
|
||||||
|
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if hasattr(module, "_tie_weights"):
|
if hasattr(module, "_tie_weights"):
|
||||||
|
|||||||
@@ -1891,9 +1891,16 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
if self.config.tie_encoder_decoder:
|
if self.config.tie_encoder_decoder:
|
||||||
# tie text encoder and decoder base model
|
# tie text encoder and decoder base model
|
||||||
decoder_base_model_prefix = self.decoder.base_model_prefix
|
decoder_base_model_prefix = self.decoder.base_model_prefix
|
||||||
self._tie_encoder_decoder_weights(
|
tied_weights = self._tie_encoder_decoder_weights(
|
||||||
self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
|
self.text_encoder,
|
||||||
|
self.decoder._modules[decoder_base_model_prefix],
|
||||||
|
self.decoder.base_model_prefix,
|
||||||
|
"text_encoder",
|
||||||
)
|
)
|
||||||
|
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
|
||||||
|
# attributed not an instance member, therefore modifying it will modify the entire class
|
||||||
|
# Leading to issues on subsequent calls by different tests or subsequent calls.
|
||||||
|
self._dynamic_tied_weights_keys = tied_weights
|
||||||
|
|
||||||
def get_audio_encoder(self):
|
def get_audio_encoder(self):
|
||||||
return self.audio_encoder
|
return self.audio_encoder
|
||||||
|
|||||||
@@ -1810,9 +1810,16 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
if self.config.tie_encoder_decoder:
|
if self.config.tie_encoder_decoder:
|
||||||
# tie text encoder and decoder base model
|
# tie text encoder and decoder base model
|
||||||
decoder_base_model_prefix = self.decoder.base_model_prefix
|
decoder_base_model_prefix = self.decoder.base_model_prefix
|
||||||
self._tie_encoder_decoder_weights(
|
tied_weights = self._tie_encoder_decoder_weights(
|
||||||
self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
|
self.text_encoder,
|
||||||
|
self.decoder._modules[decoder_base_model_prefix],
|
||||||
|
self.decoder.base_model_prefix,
|
||||||
|
"text_encoder",
|
||||||
)
|
)
|
||||||
|
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
|
||||||
|
# attributed not an instance member, therefore modifying it will modify the entire class
|
||||||
|
# Leading to issues on subsequent calls by different tests or subsequent calls.
|
||||||
|
self._dynamic_tied_weights_keys = tied_weights
|
||||||
|
|
||||||
def get_text_encoder(self):
|
def get_text_encoder(self):
|
||||||
return self.text_encoder
|
return self.text_encoder
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ if is_torch_available():
|
|||||||
_prepare_4d_attention_mask,
|
_prepare_4d_attention_mask,
|
||||||
_prepare_4d_causal_attention_mask,
|
_prepare_4d_causal_attention_mask,
|
||||||
)
|
)
|
||||||
from transformers.modeling_utils import shard_checkpoint
|
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
|
||||||
|
|
||||||
# Fake pretrained models for tests
|
# Fake pretrained models for tests
|
||||||
class BaseModel(PreTrainedModel):
|
class BaseModel(PreTrainedModel):
|
||||||
@@ -256,6 +256,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
self.assertTrue(check_models_equal(model, model_loaded))
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_manually_shared_disjointed_tensors_optimum(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
model = BertModel(config)
|
||||||
|
|
||||||
|
# Let's fuse qkv
|
||||||
|
attn = model.encoder.layer[0].attention.self
|
||||||
|
q = attn.query.weight
|
||||||
|
k = attn.key.weight
|
||||||
|
v = attn.value.weight
|
||||||
|
# Force some shared storage
|
||||||
|
qkv = torch.stack([q, k, v], dim=0)
|
||||||
|
attn.query.weight = torch.nn.Parameter(qkv[0])
|
||||||
|
attn.key.weight = torch.nn.Parameter(qkv[1])
|
||||||
|
attn.value.weight = torch.nn.Parameter(qkv[2])
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
model_loaded = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
def test_model_from_pretrained_subfolder_sharded(self):
|
def test_model_from_pretrained_subfolder_sharded(self):
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
model = BertModel(config)
|
model = BertModel(config)
|
||||||
@@ -2222,3 +2242,40 @@ class Mask4DTestHard(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.assertEqual(decoded_0, decoded_1b)
|
self.assertEqual(decoded_0, decoded_1b)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class TestTensorSharing(TestCasePlus):
|
||||||
|
def test_disjoint(self):
|
||||||
|
main = torch.zeros(10)
|
||||||
|
a = main[:5]
|
||||||
|
b = main[5:]
|
||||||
|
state_dict = {"a": a, "b": b}
|
||||||
|
|
||||||
|
shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
|
||||||
|
self.assertEqual(shared_names, [])
|
||||||
|
self.assertEqual(disjoint_names, ["a", "b"])
|
||||||
|
|
||||||
|
a = main[::2]
|
||||||
|
b = main[1::2]
|
||||||
|
state_dict = {"a": a, "b": b}
|
||||||
|
|
||||||
|
shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
|
||||||
|
self.assertEqual(shared_names, [{"a", "b"}])
|
||||||
|
self.assertEqual(disjoint_names, [])
|
||||||
|
|
||||||
|
def test_identical(self):
|
||||||
|
a = torch.zeros(10)
|
||||||
|
b = a
|
||||||
|
state_dict = {"a": a, "b": b}
|
||||||
|
|
||||||
|
shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
|
||||||
|
self.assertEqual(shared_names, [])
|
||||||
|
self.assertEqual(identical_names, [{"a", "b"}])
|
||||||
|
|
||||||
|
b = a[:5]
|
||||||
|
state_dict = {"a": a, "b": b}
|
||||||
|
|
||||||
|
shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
|
||||||
|
self.assertEqual(shared_names, [{"a", "b"}])
|
||||||
|
self.assertEqual(identical_names, [])
|
||||||
|
|||||||
Reference in New Issue
Block a user