Pytorch - Lazy initialization of models (#11471)
* lazy_init_weights * remove ipdb * save int * add necessary code * remove unnecessary utils * Update src/transformers/models/t5/modeling_t5.py * clean * add tests * correct * finish tests * finish tests * fix some more tests * fix xlnet & transfo-xl * fix more tests * make sure tests are independent * fix tests more * finist tests * final touches * Update src/transformers/modeling_utils.py * Apply suggestions from code review * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * clean tests * give arg positive name * add more mock weights to xlnet Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8fa8e19429
commit
3e3e41ae20
@@ -195,6 +195,7 @@ class ExamplesTests(TestCasePlus):
|
|||||||
--per_device_train_batch_size=2
|
--per_device_train_batch_size=2
|
||||||
--per_device_eval_batch_size=2
|
--per_device_eval_batch_size=2
|
||||||
--num_train_epochs={epochs}
|
--num_train_epochs={epochs}
|
||||||
|
--seed 7
|
||||||
""".split()
|
""".split()
|
||||||
|
|
||||||
if torch_device != "cuda":
|
if torch_device != "cuda":
|
||||||
|
|||||||
177
src/transformers/modeling_utils.py
Executable file → Normal file
177
src/transformers/modeling_utils.py
Executable file → Normal file
@@ -18,6 +18,7 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
@@ -50,6 +51,26 @@ from .utils import logging
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level switch read by `PreTrainedModel.init_weights`: when False, the
# random weight initialization pass is skipped.
_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    While the context is active and ``_enable`` is truthy, the module-level ``_init_weights`` flag is set to
    :obj:`False`. On exit (including when the body raises) the flag is restored to the value it had on entry, so
    nested uses do not re-enable initialization while an outer context is still active.

    Args:
        _enable (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If :obj:`False`, the flag is left untouched for the duration of the context.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    # Remember the state on entry instead of assuming it was True.
    previous = _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = previous
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn import Identity
|
from torch.nn import Identity
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -768,16 +789,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
"""
|
"""
|
||||||
Initializes and prunes weights if needed.
|
If needed prunes and maybe initializes weights.
|
||||||
"""
|
"""
|
||||||
# Initialize weights
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
# Prune heads if needed
|
# Prune heads if needed
|
||||||
if self.config.pruned_heads:
|
if self.config.pruned_heads:
|
||||||
self.prune_heads(self.config.pruned_heads)
|
self.prune_heads(self.config.pruned_heads)
|
||||||
|
|
||||||
# Tie weights if needed
|
if _init_weights:
|
||||||
|
# Initialize weights
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
# Tie weights should be skipped when not initializing all weights
|
||||||
|
# since from_pretrained(...) calls tie weights anyways
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
|
|
||||||
def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
|
def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
|
||||||
@@ -956,6 +979,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||||
Please refer to the mirror site for more information.
|
Please refer to the mirror site for more information.
|
||||||
|
_fast_init (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to use fast initialization, i.e. to skip weight initialization when the model is instantiated.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
One should only disable `_fast_init` to ensure backwards compatibility with
|
||||||
|
``transformers.__version__ < 4.6.0`` for seeded model initialization. This argument will be removed
|
||||||
|
at the next major version. See `pull request 11471
|
||||||
|
<https://github.com/huggingface/transformers/pull/11471>`__ for more information.
|
||||||
|
|
||||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||||
@@ -1012,6 +1045,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
mirror = kwargs.pop("mirror", None)
|
mirror = kwargs.pop("mirror", None)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
|
_fast_init = kwargs.pop("_fast_init", True)
|
||||||
|
|
||||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
@@ -1119,7 +1153,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
@@ -1127,24 +1160,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||||
# and memory copying it on CPU or each GPU first
|
# and memory copying it on CPU or each GPU first
|
||||||
with deepspeed.zero.Init(config=deepspeed_config()):
|
with deepspeed.zero.Init(config=deepspeed_config()):
|
||||||
|
with no_init_weights(_enable=_fast_init):
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
else:
|
else:
|
||||||
|
with no_init_weights(_enable=_fast_init):
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
if state_dict is None and not (from_tf or from_flax):
|
|
||||||
try:
|
|
||||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
|
||||||
except Exception:
|
|
||||||
raise OSError(
|
|
||||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
|
||||||
f"at '{resolved_archive_file}'"
|
|
||||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
|
||||||
)
|
|
||||||
|
|
||||||
missing_keys = []
|
|
||||||
unexpected_keys = []
|
|
||||||
error_msgs = []
|
|
||||||
|
|
||||||
if from_tf:
|
if from_tf:
|
||||||
if resolved_archive_file.endswith(".index"):
|
if resolved_archive_file.endswith(".index"):
|
||||||
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
||||||
@@ -1173,6 +1194,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
|
if state_dict is None:
|
||||||
|
try:
|
||||||
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||||
|
except Exception:
|
||||||
|
raise OSError(
|
||||||
|
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
||||||
|
f"at '{resolved_archive_file}'"
|
||||||
|
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
||||||
|
)
|
||||||
|
|
||||||
|
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
|
||||||
|
model, state_dict, pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure token embedding weights are still tied if needed
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
|
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if output_loading_info:
|
||||||
|
loading_info = {
|
||||||
|
"missing_keys": missing_keys,
|
||||||
|
"unexpected_keys": unexpected_keys,
|
||||||
|
"error_msgs": error_msgs,
|
||||||
|
}
|
||||||
|
return model, loading_info
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path):
|
||||||
|
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
new_keys = []
|
new_keys = []
|
||||||
@@ -1188,17 +1242,53 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
for old_key, new_key in zip(old_keys, new_keys):
|
for old_key, new_key in zip(old_keys, new_keys):
|
||||||
state_dict[new_key] = state_dict.pop(old_key)
|
state_dict[new_key] = state_dict.pop(old_key)
|
||||||
|
|
||||||
|
# Retrieve missing & unexpected_keys
|
||||||
|
expected_keys = list(model.state_dict().keys())
|
||||||
|
loaded_keys = list(state_dict.keys())
|
||||||
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
|
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
||||||
|
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
|
||||||
|
remove_prefix = not has_prefix_module and expects_prefix_module
|
||||||
|
add_prefix = has_prefix_module and not expects_prefix_module
|
||||||
|
|
||||||
|
if remove_prefix:
|
||||||
|
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
|
||||||
|
elif add_prefix:
|
||||||
|
expected_keys = [".".join([prefix, s]) for s in expected_keys]
|
||||||
|
|
||||||
|
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||||
|
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||||
|
|
||||||
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||||
|
# the user.
|
||||||
|
if cls._keys_to_ignore_on_load_missing is not None:
|
||||||
|
for pat in cls._keys_to_ignore_on_load_missing:
|
||||||
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
|
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||||
|
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||||
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
|
# initialize the modules that have at least one weight missing from the checkpoint
|
||||||
|
unintialized_modules = model.retrieve_modules_from_names(
|
||||||
|
missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
|
||||||
|
)
|
||||||
|
for module in unintialized_modules:
|
||||||
|
model._init_weights(module)
|
||||||
# copy state_dict so _load_from_state_dict can modify it
|
# copy state_dict so _load_from_state_dict can modify it
|
||||||
metadata = getattr(state_dict, "_metadata", None)
|
metadata = getattr(state_dict, "_metadata", None)
|
||||||
state_dict = state_dict.copy()
|
state_dict = state_dict.copy()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
state_dict._metadata = metadata
|
state_dict._metadata = metadata
|
||||||
|
|
||||||
|
error_msgs = []
|
||||||
|
|
||||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||||
# so we need to apply the function recursively.
|
# so we need to apply the function recursively.
|
||||||
def load(module: nn.Module, prefix=""):
|
def load(module: nn.Module, prefix=""):
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
@@ -1218,7 +1308,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Make sure we are able to load base models as well as derived models (with heads)
|
# Make sure we are able to load base models as well as derived models (with heads)
|
||||||
start_prefix = ""
|
start_prefix = ""
|
||||||
model_to_load = model
|
model_to_load = model
|
||||||
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
|
|
||||||
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
||||||
start_prefix = cls.base_model_prefix + "."
|
start_prefix = cls.base_model_prefix + "."
|
||||||
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||||
@@ -1226,23 +1315,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
load(model_to_load, prefix=start_prefix)
|
load(model_to_load, prefix=start_prefix)
|
||||||
|
|
||||||
if model.__class__.__name__ != model_to_load.__class__.__name__:
|
|
||||||
base_model_state_dict = model_to_load.state_dict().keys()
|
|
||||||
head_model_state_dict_without_base_prefix = [
|
|
||||||
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
|
||||||
]
|
|
||||||
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
|
||||||
|
|
||||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
|
||||||
# the user.
|
|
||||||
if cls._keys_to_ignore_on_load_missing is not None:
|
|
||||||
for pat in cls._keys_to_ignore_on_load_missing:
|
|
||||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
|
||||||
|
|
||||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
|
||||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||||
@@ -1269,21 +1341,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
error_msg = "\n\t".join(error_msgs)
|
error_msg = "\n\t".join(error_msgs)
|
||||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||||
# make sure token embedding weights are still tied if needed
|
|
||||||
model.tie_weights()
|
|
||||||
|
|
||||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
return model, missing_keys, unexpected_keys, error_msgs
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if output_loading_info:
|
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
|
||||||
loading_info = {
|
module_keys = set([".".join(key.split(".")[:-1]) for key in names])
|
||||||
"missing_keys": missing_keys,
|
|
||||||
"unexpected_keys": unexpected_keys,
|
|
||||||
"error_msgs": error_msgs,
|
|
||||||
}
|
|
||||||
return model, loading_info
|
|
||||||
|
|
||||||
return model
|
retrieved_modules = []
|
||||||
|
# retrieve all modules that have at least one missing weight name
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if remove_prefix:
|
||||||
|
name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
|
||||||
|
elif add_prefix:
|
||||||
|
name = ".".join([self.base_model_prefix, name])
|
||||||
|
|
||||||
|
if name in module_keys:
|
||||||
|
retrieved_modules.append(module)
|
||||||
|
|
||||||
|
return retrieved_modules
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
|
|||||||
@@ -177,6 +177,103 @@ class ModelTesterMixin:
|
|||||||
for k in _keys_to_ignore_on_save:
|
for k in _keys_to_ignore_on_save:
|
||||||
self.assertNotIn(k, state_dict_saved)
|
self.assertNotIn(k, state_dict_saved)
|
||||||
|
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for a model's ``_init_weights``.

    Fills ``module.weight`` and ``module.bias`` with the constant 3 whenever the attribute exists and is not
    :obj:`None`.
    """
    for attr_name in ("weight", "bias"):
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(3)
|
||||||
|
|
||||||
|
def test_save_load_fast_init_from_base(self):
    """
    Load a base-model checkpoint into every non-base model class, once with fast init and once with
    ``_fast_init=False``, and check that both loads yield the same state dict.

    One randomly chosen key is removed from the saved state dict so that at least one weight has to be
    (mock-)initialized instead of loaded.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = MODEL_MAPPING[config.__class__]

    if isinstance(base_class, tuple):
        base_class = base_class[0]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        # make a copy of model class to not break future tests
        # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
        class CopyClass(model_class):
            pass

        model_class_copy = CopyClass

        # make sure that all keys are expected for test
        model_class_copy._keys_to_ignore_on_load_missing = []

        # make init deterministic, but make sure that
        # non-initialized weights throw errors nevertheless
        model_class_copy._init_weights = self._mock_init_weights

        model = base_class(config)
        state_dict = model.state_dict()

        # this will often delete a single weight of a multi-weight module
        # to test an edge case
        random_key_to_del = random.choice(list(state_dict.keys()))
        del state_dict[random_key_to_del]

        # save config + weights, then overwrite the weights with the reduced state dict
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

            model_fast_init = model_class_copy.from_pretrained(tmpdirname)
            model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)

        fast_state = model_fast_init.state_dict()
        slow_state = model_slow_init.state_dict()
        for key in fast_state.keys():
            diff = slow_state[key] - fast_state[key]
            # Use the largest absolute element-wise difference: a signed sum lets
            # positive and negative deviations cancel (or go negative) and pass.
            max_diff = diff.abs().max().item() if diff.numel() > 0 else 0
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_save_load_fast_init_to_base(self):
    """
    Load a checkpoint saved from every non-base model class into the base model class, once with fast init
    and once with ``_fast_init=False``, and check that both loads yield the same state dict.

    One randomly chosen key is removed from the saved state dict so that at least one weight may have to be
    (mock-)initialized instead of loaded.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = MODEL_MAPPING[config.__class__]

    if isinstance(base_class, tuple):
        base_class = base_class[0]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        # make a copy of model class to not break future tests
        # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
        class CopyClass(base_class):
            pass

        base_class_copy = CopyClass

        # make sure that all keys are expected for test
        base_class_copy._keys_to_ignore_on_load_missing = []

        # make init deterministic, but make sure that
        # non-initialized weights throw errors nevertheless
        base_class_copy._init_weights = self._mock_init_weights

        model = model_class(config)
        state_dict = model.state_dict()

        # this will often delete a single weight of a multi-weight module
        # to test an edge case
        random_key_to_del = random.choice(list(state_dict.keys()))
        del state_dict[random_key_to_del]

        # save only the config, then write the reduced state dict as the checkpoint
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.config.save_pretrained(tmpdirname)
            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

            model_fast_init = base_class_copy.from_pretrained(tmpdirname)
            model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

        fast_state = model_fast_init.state_dict()
        slow_state = model_slow_init.state_dict()
        for key in fast_state.keys():
            diff = slow_state[key] - fast_state[key]
            # Use the largest absolute element-wise difference: a signed sum lets
            # positive and negative deviations cancel (or go negative) and pass.
            max_diff = diff.abs().max().item() if diff.numel() > 0 else 0
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -400,6 +400,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
# overwrite from test_modeling_common
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for the model's ``_init_weights`` (overrides the common test mixin version).

    Fills ``weight``, ``bias`` and the extra attributes listed below with the constant 3 whenever the
    attribute exists on ``module`` and is not :obj:`None`.
    """
    attr_names = ("weight", "bias", "r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed")
    for attr_name in attr_names:
        tensor = getattr(module, attr_name, None)
        if tensor is None:
            continue
        tensor.data.fill_(3)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
|
class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
@@ -443,6 +455,18 @@ class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
# overwrite from test_modeling_common
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for the model's ``_init_weights`` (overrides the common test mixin version).

    Every attribute named below that exists on ``module`` and is not :obj:`None` gets filled with the
    constant 3.
    """
    standard_attrs = ["weight", "bias"]
    extra_attrs = ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]
    for attr_name in standard_attrs + extra_attrs:
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(3)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
@@ -348,6 +348,31 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
[expected_shape] * len(iter_hidden_states),
|
[expected_shape] * len(iter_hidden_states),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# overwrite from test_modeling_common
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for the model's ``_init_weights`` (overrides the common test mixin version).

    - ``weight``, ``cluster_weight``, ``bias`` and ``cluster_bias`` are filled with 3 when present and not
      :obj:`None`.
    - Every non-:obj:`None` entry of ``emb_projs`` and ``out_projs`` is set to the constant 0.0003.
    - ``r_emb``, ``r_w_bias``, ``r_r_bias`` and ``r_bias`` are filled with 3 when present and not :obj:`None`.
    """
    for attr_name in ("weight", "cluster_weight", "bias", "cluster_bias"):
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(3)

    for list_name in ("emb_projs", "out_projs"):
        if not hasattr(module, list_name):
            continue
        for projection in getattr(module, list_name):
            if projection is not None:
                torch.nn.init.constant_(projection, 0.0003)

    for attr_name in ("r_emb", "r_w_bias", "r_r_bias", "r_bias"):
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(3)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# overwrite from test_modeling_common
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for the model's ``_init_weights`` (overrides the common test mixin version).

    Fills ``weight``, ``weight_g`` and ``bias`` with the constant 3 whenever the attribute exists on
    ``module`` and is not :obj:`None`.
    """
    if hasattr(module, "weight") and module.weight is not None:
        module.weight.data.fill_(3)
    # Guard on weight_g itself; the original checked `module.weight` here.
    if hasattr(module, "weight_g") and module.weight_g is not None:
        module.weight_g.data.fill_(3)
    if hasattr(module, "bias") and module.bias is not None:
        module.bias.data.fill_(3)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
@@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# overwrite from test_modeling_common
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for the model's ``_init_weights`` (overrides the common test mixin version).

    Fills ``weight``, ``weight_g`` and ``bias`` with the constant 3 whenever the attribute exists on
    ``module`` and is not :obj:`None`.
    """
    if hasattr(module, "weight") and module.weight is not None:
        module.weight.data.fill_(3)
    # Guard on weight_g itself; the original checked `module.weight` here.
    if hasattr(module, "weight_g") and module.weight_g is not None:
        module.weight_g.data.fill_(3)
    if hasattr(module, "bias") and module.bias is not None:
        module.bias.data.fill_(3)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
|||||||
@@ -594,6 +594,18 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
|||||||
# xlnet cannot keep gradients in attentions or hidden states
|
# xlnet cannot keep gradients in attentions or hidden states
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# overwrite from test_modeling_common
|
||||||
|
def _mock_init_weights(self, module):
    """
    Deterministic replacement for the model's ``_init_weights`` (overrides the common test mixin version).

    Fills ``weight``, ``bias`` and the additional attributes listed below with the constant 3 whenever the
    attribute exists on ``module`` and is not :obj:`None`.
    """
    attr_names = (
        "weight",
        "bias",
        "q",
        "k",
        "v",
        "o",
        "r",
        "r_r_bias",
        "r_s_bias",
        "r_w_bias",
        "seg_embed",
        "mask_emb",
    )
    for attr_name in attr_names:
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            tensor.data.fill_(3)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user