From 3e3e41ae20be8e0fd9dd077b61be724f1098d407 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 May 2021 17:22:20 +0200 Subject: [PATCH] Pytorch - Lazy initialization of models (#11471) * lazy_init_weights * remove ipdb * save int * add necessary code * remove unnecessary utils * Update src/transformers/models/t5/modeling_t5.py * clean * add tests * correct * finish tests * finish tests * fix some more tests * fix xlnet & transfo-xl * fix more tests * make sure tests are independent * fix tests more * finist tests * final touches * Update src/transformers/modeling_utils.py * Apply suggestions from code review * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman * clean tests * give arg positive name * add more mock weights to xlnet Co-authored-by: Stas Bekman --- examples/pytorch/test_examples.py | 1 + src/transformers/modeling_utils.py | 309 ++++++++++++++++++----------- tests/test_modeling_common.py | 97 +++++++++ tests/test_modeling_funnel.py | 24 +++ tests/test_modeling_transfo_xl.py | 25 +++ tests/test_modeling_wav2vec2.py | 18 ++ tests/test_modeling_xlnet.py | 12 ++ 7 files changed, 369 insertions(+), 117 deletions(-) mode change 100755 => 100644 src/transformers/modeling_utils.py diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 5d4f0c24c1..717bca47c6 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -195,6 +195,7 @@ class ExamplesTests(TestCasePlus): --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --num_train_epochs={epochs} + --seed 7 """.split() if torch_device != "cuda": diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py old mode 100755 new mode 100644 index ee81a3adf1..8160b4ba37 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -18,6 +18,7 @@ import inspect import os import re import warnings +from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -50,6 +51,26 @@ from .utils import logging logger = logging.get_logger(__name__) + +_init_weights = True + + +@contextmanager +def no_init_weights(_enable=True): + """ + Context manager to globally disable weight initialization to speed up loading large models. + + TODO(Patrick): Delete safety argument `_enable=True` at next major version. . + """ + global _init_weights + if _enable: + _init_weights = False + try: + yield + finally: + _init_weights = True + + try: from torch.nn import Identity except ImportError: @@ -768,17 +789,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix def init_weights(self): """ - Initializes and prunes weights if needed. + If needed prunes and maybe initializes weights. """ - # Initialize weights - self.apply(self._init_weights) - # Prune heads if needed if self.config.pruned_heads: self.prune_heads(self.config.pruned_heads) - # Tie weights if needed - self.tie_weights() + if _init_weights: + # Initialize weights + self.apply(self._init_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() def prune_heads(self, heads_to_prune: Dict[int, List[int]]): """ @@ -956,6 +979,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. + _fast_init(:obj:`bool`, `optional`, defaults to `:obj:`True`): + Whether or not to disable fast initialization. + + .. warning:: + + One should only disable `_fast_init` to ensure backwards compatibility with + ``transformers.__version__ < 4.6.0`` for seeded model initialization. This argument will be removed + at the next major version. See `pull request 11471 + `__ for more information. + kwargs (remaining dictionary of keyword arguments, `optional`): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or @@ -1012,6 +1045,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix mirror = kwargs.pop("mirror", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -1119,7 +1153,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix config.name_or_path = pretrained_model_name_or_path # Instantiate model. - if is_deepspeed_zero3_enabled(): import deepspeed @@ -1127,23 +1160,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # this immediately partitions the model across all gpus, to avoid the overhead in time # and memory copying it on CPU or each GPU first with deepspeed.zero.Init(config=deepspeed_config()): - model = cls(config, *model_args, **model_kwargs) + with no_init_weights(_enable=_fast_init): + model = cls(config, *model_args, **model_kwargs) else: - model = cls(config, *model_args, **model_kwargs) - - if state_dict is None and not (from_tf or from_flax): - try: - state_dict = torch.load(resolved_archive_file, map_location="cpu") - except Exception: - raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " - f"at '{resolved_archive_file}'" - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " - ) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] + with no_init_weights(_enable=_fast_init): + model = cls(config, *model_args, **model_kwargs) if from_tf: if resolved_archive_file.endswith(".index"): @@ -1173,102 +1194,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) raise else: - # Convert old format to new format if needed from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if "gamma" in key: - new_key = key.replace("gamma", "weight") - if "beta" in key: - new_key = key.replace("beta", "bias") - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) + if state_dict is None: + try: + state_dict = torch.load(resolved_archive_file, map_location="cpu") + except Exception: + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " + f"at '{resolved_archive_file}'" + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " + ) - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata + model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model( + model, state_dict, pretrained_model_name_or_path + ) - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: nn.Module, prefix=""): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - if is_deepspeed_zero3_enabled(): - import deepspeed - - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if torch.distributed.get_rank() == 0: - module._load_from_state_dict(*args) - else: - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - # Make sure we are able to load base models as well as derived models (with heads) - start_prefix = "" - model_to_load = model - has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) - if not hasattr(model, cls.base_model_prefix) and has_prefix_module: - start_prefix = cls.base_model_prefix + "." - if hasattr(model, cls.base_model_prefix) and not has_prefix_module: - model_to_load = getattr(model, cls.base_model_prefix) - - load(model_to_load, prefix=start_prefix) - - if model.__class__.__name__ != model_to_load.__class__.__name__: - base_model_state_dict = model_to_load.state_dict().keys() - head_model_state_dict_without_base_prefix = [ - key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() - ] - missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) - - # Some models may have keys that are not in the state by design, removing them before needlessly warning - # the user. - if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " - f"initializing {model.__class__.__name__}: {unexpected_keys}\n" - f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" - f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " - f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " - f"and are newly initialized: {missing_keys}\n" - f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - else: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" - f"If your task is similar to the task the model of the checkpoint was trained on, " - f"you can already use {model.__class__.__name__} for predictions without further training." - ) - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") # make sure token embedding weights are still tied if needed model.tie_weights() @@ -1285,6 +1224,142 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return model + @classmethod + def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path): + + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # Retrieve missing & unexpected_keys + expected_keys = list(model.state_dict().keys()) + loaded_keys = list(state_dict.keys()) + prefix = model.base_model_prefix + + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + remove_prefix = not has_prefix_module and expects_prefix_module + add_prefix = has_prefix_module and not expects_prefix_module + + if remove_prefix: + expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys] + elif add_prefix: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + # tie unintialized modules + unintialized_modules = model.retrieve_modules_from_names( + missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix + ) + for module in unintialized_modules: + model._init_weights(module) + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if is_deepspeed_zero3_enabled(): + import deepspeed + + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + if hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + + load(model_to_load, prefix=start_prefix) + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." + ) + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + return model, missing_keys, unexpected_keys, error_msgs + + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): + module_keys = set([".".join(key.split(".")[:-1]) for key in names]) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name + elif add_prefix: + name = ".".join([self.base_model_prefix, name]) + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + class Conv1D(nn.Module): """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f83d65b51a..a98d406d2f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -177,6 +177,103 @@ class ModelTesterMixin: for k in _keys_to_ignore_on_save: self.assertNotIn(k, state_dict_saved) + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + def test_save_load_fast_init_from_base(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + base_class = MODEL_MAPPING[config.__class__] + + if isinstance(base_class, tuple): + base_class = base_class[0] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + # make a copy of model class to not break future tests + # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class + class CopyClass(model_class): + pass + + model_class_copy = CopyClass + + # make sure that all keys are expected for test + model_class_copy._keys_to_ignore_on_load_missing = [] + + # make init deterministic, but make sure that + # non-initialized weights throw errors nevertheless + model_class_copy._init_weights = self._mock_init_weights + + model = base_class(config) + state_dict = model.state_dict() + + # this will often delete a single weight of a multi-weight module + # to test an edge case + random_key_to_del = random.choice(list(state_dict.keys())) + del state_dict[random_key_to_del] + + # check that certain keys didn't get saved with the model + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) + + model_fast_init = model_class_copy.from_pretrained(tmpdirname) + model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False) + + for key in model_fast_init.state_dict().keys(): + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + def test_save_load_fast_init_to_base(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + base_class = MODEL_MAPPING[config.__class__] + + if isinstance(base_class, tuple): + base_class = base_class[0] + + for model_class in self.all_model_classes: + + if model_class == base_class: + continue + + # make a copy of model class to not break future tests + # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class + class CopyClass(base_class): + pass + + base_class_copy = CopyClass + + # make sure that all keys are expected for test + base_class_copy._keys_to_ignore_on_load_missing = [] + + # make init deterministic, but make sure that + # non-initialized weights throw errors nevertheless + base_class_copy._init_weights = self._mock_init_weights + + model = model_class(config) + state_dict = model.state_dict() + + # this will often delete a single weight of a multi-weight module + # to test an edge case + random_key_to_del = random.choice(list(state_dict.keys())) + del state_dict[random_key_to_del] + + # check that certain keys didn't get saved with the model + with tempfile.TemporaryDirectory() as tmpdirname: + model.config.save_pretrained(tmpdirname) + torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) + + model_fast_init = base_class_copy.from_pretrained(tmpdirname) + model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) + + for key in model_fast_init.state_dict().keys(): + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_funnel.py b/tests/test_modeling_funnel.py index 9be00caeb7..c7f8f7bf0e 100644 --- a/tests/test_modeling_funnel.py +++ b/tests/test_modeling_funnel.py @@ -400,6 +400,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + @require_torch class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase): @@ -443,6 +455,18 @@ class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase): loss = model(**inputs).loss loss.backward() + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + @require_torch @require_sentencepiece diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 6f771ece01..adbaf3642e 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -348,6 +348,31 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC [expected_shape] * len(iter_hidden_states), ) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "cluster_weight") and module.cluster_weight is not None: + module.cluster_weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + if hasattr(module, "cluster_bias") and module.cluster_bias is not None: + module.cluster_bias.data.fill_(3) + + if hasattr(module, "emb_projs"): + for i in range(len(module.emb_projs)): + if module.emb_projs[i] is not None: + torch.nn.init.constant_(module.emb_projs[i], 0.0003) + if hasattr(module, "out_projs"): + for i in range(len(module.out_projs)): + if module.out_projs[i] is not None: + torch.nn.init.constant_(module.out_projs[i], 0.0003) + + for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + @require_torch class TransfoXLModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index abb57eb9af..f2bb897e55 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") @@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 93031d0371..2ab4940689 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -594,6 +594,18 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) # xlnet cannot keep gradients in attentions or hidden states return + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + def _check_hidden_states_for_generate( self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 ):