From 04b2f13c37791204b02178392671d9dae52065be Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 9 Feb 2023 15:46:26 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20Enforce?= =?UTF-8?q?=20single=20model=20initialization=20(#21431)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Enforce single model initialization * Add OneFormer example for problem 3 * Do it the Stas way * Actually rename the uses... * Rewrite test * Try to change the test this way * Fix all init slow/fast tests * Break connection * Fix more tests * Fix test for initialization * Remove custom test * Quality * Fix last failing tests * The end? --- docs/source/en/add_new_model.mdx | 42 +++++++++++++++ src/transformers/modeling_utils.py | 41 ++++++++++++--- .../models/altclip/modeling_altclip.py | 2 + src/transformers/models/bart/modeling_bart.py | 8 +-- .../modeling_bigbird_pegasus.py | 8 +-- src/transformers/models/fsmt/modeling_fsmt.py | 3 ++ src/transformers/models/led/modeling_led.py | 8 +-- .../maskformer/modeling_maskformer_swin.py | 3 ++ .../models/mbart/modeling_mbart.py | 8 +-- src/transformers/models/mvp/modeling_mvp.py | 7 +-- .../models/oneformer/modeling_oneformer.py | 1 + .../models/plbart/modeling_plbart.py | 5 +- .../models/upernet/modeling_upernet.py | 6 +++ .../models/wav2vec2/modeling_wav2vec2.py | 15 ++++-- .../modeling_wav2vec2_conformer.py | 15 ++++-- .../models/wavlm/modeling_wavlm.py | 1 - tests/models/bart/test_modeling_bart.py | 3 ++ tests/models/deta/test_modeling_deta.py | 24 +++++---- tests/models/dpt/test_modeling_dpt.py | 25 ++++++++- tests/models/dpt/test_modeling_dpt_hybrid.py | 25 ++++++++- .../layoutlmv2/test_modeling_layoutlmv2.py | 52 ------------------- .../prophetnet/test_modeling_prophetnet.py | 4 +- .../models/reformer/test_modeling_reformer.py | 19 +++---- .../vit_hybrid/test_modeling_vit_hybrid.py | 24 ++++++++- 
tests/test_modeling_common.py | 51 +++++++++++++----- 25 files changed, 277 insertions(+), 123 deletions(-) diff --git a/docs/source/en/add_new_model.mdx b/docs/source/en/add_new_model.mdx index 60dda2b02c..56a130f14e 100644 --- a/docs/source/en/add_new_model.mdx +++ b/docs/source/en/add_new_model.mdx @@ -492,6 +492,48 @@ model = BrandNewBertModel(BrandNewBertConfig()) The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with random weights, thus making sure that the `init()` methods of all components works. +Note that all random initialization should happen in the `_init_weights` method of your `BrandnewBertPreTrainedModel` +class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the +BERT `_init_weights` method: + +```py +def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) +``` + +You can have some more custom schemes if you need a special initialization for some modules. For instance, in +`Wav2Vec2ForPreTraining`, the last two linear layers need to have the initialization of the regular PyTorch `nn.Linear` +but all the other ones should use an initialization as above. 
This is coded like this: + +```py +def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Wav2Vec2ForPreTraining): + module.project_hid.reset_parameters() + module.project_q.reset_parameters() + module.project_hid._is_hf_initialized = True + module.project_q._is_hf_initialized = True + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() +``` + +The `_is_hf_initialized` flag is internally used to make sure we only initialize a submodule once. By setting it to +`True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization we did is not overridden later on, as +the `_init_weights` function won't be applied to them. + **6. Write a conversion script** Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2b74546418..160cf814b8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -436,6 +436,17 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ) +def set_initialized_submodules(model, state_dict_keys): + """ + Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state + dict. + """ + for module_name, module in model.named_modules(): + loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")] + if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: + module._is_hf_initialized = True + + def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] @@ -1176,7 +1187,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix """ Initialize the weights. 
This method should be overridden by derived class. """ - raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}") + pass + + def _initialize_weights(self, module): + """ + Initialize the weights if they are not already initialized. + """ + if getattr(module, "_is_hf_initialized", False): + return + self._init_weights(module) + module._is_hf_initialized = True def tie_weights(self): """ @@ -1505,7 +1525,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix def init_weights(self): """ - If needed prunes and maybe initializes weights. + If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. """ # Prune heads if needed if self.config.pruned_heads: @@ -1513,7 +1534,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if _init_weights: # Initialize weights - self.apply(self._init_weights) + self.apply(self._initialize_weights) # Tie weights should be skipped when not initializing all weights # since from_pretrained(...) calls tie weights anyways @@ -2713,11 +2734,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. if _fast_init: - uninitialized_modules = model.retrieve_modules_from_names( - missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model - ) - for module in uninitialized_modules: - model._init_weights(module) + if remove_prefix_from_model: + _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] + elif add_prefix_to_model: + _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] + else: + _loaded_keys = loaded_keys + set_initialized_submodules(model, _loaded_keys) + # This will only initialize submodules that are not marked as initialized by the line above. 
+ model.apply(model._initialize_weights) # Set some modules to fp32 if any if keep_in_fp32_modules is not None: diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 7d150d5734..8f05e71a46 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -1067,10 +1067,12 @@ class AltCLIPPreTrainedModel(PreTrainedModel): module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) + module.text_projection._is_hf_initialized = True nn.init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) + module.visual_projection._is_hf_initialized = True elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 903231cc06..1f2c4a14ed 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1473,8 +1473,9 @@ class BartForSequenceClassification(BartPretrainedModel): config.num_labels, config.classifier_dropout, ) - self.model._init_weights(self.classification_head.dense) - self.model._init_weights(self.classification_head.out_proj) + + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -1601,7 +1602,8 @@ class BartForQuestionAnswering(BartPretrainedModel): self.model = BartModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.model._init_weights(self.qa_outputs) + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py 
b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index bb91e62b03..ddbb0420c5 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2658,8 +2658,9 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel): config.num_labels, config.classifier_dropout, ) - self.model._init_weights(self.classification_head.dense) - self.model._init_weights(self.classification_head.out_proj) + + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -2785,7 +2786,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel): self.model = BigBirdPegasusModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.model._init_weights(self.qa_outputs) + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 4ad4dec842..4bc01b6fef 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1186,6 +1186,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): base_model = FSMTModel(config) self.model = base_model + # Initialize weights and apply final processing + self.post_init() + @add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @add_end_docstrings(FSMT_GENERATION_EXAMPLE) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 5181be6b83..6c3314b533 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py 
@@ -2543,8 +2543,9 @@ class LEDForSequenceClassification(LEDPreTrainedModel): config.num_labels, config.classifier_dropout, ) - self.led._init_weights(self.classification_head.dense) - self.led._init_weights(self.classification_head.out_proj) + + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -2672,7 +2673,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel): self.led = LEDModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.led._init_weights(self.qa_outputs) + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index f3c5577ab8..91a61712ff 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -866,6 +866,9 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels]) + # Initialize weights and apply final processing + self.post_init() + @property def channels(self): return [self.out_feature_channels[name] for name in self.out_features] diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 688fc9fc9c..122dceaae4 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1447,8 +1447,9 @@ class MBartForSequenceClassification(MBartPreTrainedModel): config.num_labels, config.classifier_dropout, ) - self.model._init_weights(self.classification_head.dense) - self.model._init_weights(self.classification_head.out_proj) + + # Initialize weights and 
apply final processing + self.post_init() @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -1574,7 +1575,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): self.model = MBartModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.model._init_weights(self.qa_outputs) + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index b6a6a9c328..34650b2cb5 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1610,8 +1610,8 @@ class MvpForSequenceClassification(MvpPreTrainedModel): config.classifier_dropout, ) - self.model._init_weights(self.classification_head.dense) - self.model._init_weights(self.classification_head.out_proj) + # Initialize weights and apply final processing + self.post_init() def set_lightweight_tuning(self): self.model.set_lightweight_tuning() @@ -1737,7 +1737,8 @@ class MvpForQuestionAnswering(MvpPreTrainedModel): self.model = MvpModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - self.model._init_weights(self.qa_outputs) + # Initialize weights and apply final processing + self.post_init() def set_lightweight_tuning(self): self.model.set_lightweight_tuning() diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 84539b83d9..8e41ff8692 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2801,6 +2801,7 @@ class OneFormerPreTrainedModel(PreTrainedModel): elif isinstance(module, OneFormerTransformerDecoder): nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) 
nn.init.constant_(module.query_input_projection.bias, 0) + module.query_input_projection._is_hf_initialized = True elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 12f761c4f9..97ed3a34b9 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1420,8 +1420,9 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel): config.num_labels, config.classifier_dropout, ) - self.model._init_weights(self.classification_head.dense) - self.model._init_weights(self.classification_head.out_proj) + + # Initialize weights and apply final processing + self.post_init() @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 79a76f8cd9..1c2d37b0f2 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -301,6 +301,12 @@ class UperNetPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True + def _init_weights(self, module): + if isinstance(module, UperNetPreTrainedModel): + module.backbone.init_weights() + module.decode_head.init_weights() + module.auxiliary_head.init_weights() + def init_weights(self): """Initialize the weights""" self.backbone.init_weights() diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 3aa983f8e5..ab599fbfd8 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py 
@@ -1049,8 +1049,14 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" + # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. + if isinstance(module, Wav2Vec2ForPreTraining): + module.project_hid.reset_parameters() + module.project_q.reset_parameters() + module.project_hid._is_hf_initialized = True + module.project_q._is_hf_initialized = True # gumbel softmax requires special init - if isinstance(module, Wav2Vec2GumbelVectorQuantizer): + elif isinstance(module, Wav2Vec2GumbelVectorQuantizer): module.weight_proj.weight.data.normal_(mean=0.0, std=1) module.weight_proj.bias.data.zero_() nn.init.uniform_(module.codevectors) @@ -1345,13 +1351,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): self.quantizer = Wav2Vec2GumbelVectorQuantizer(config) - # Initialize weights and apply final processing - self.post_init() - - # make sure that project_hid & project_q are initialized like normal linear layers self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + # Initialize weights and apply final processing + self.post_init() + def set_gumbel_temperature(self, temperature: int): """ Set the Gumbel softmax temperature to a given value. Only necessary for training diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 12ce81465f..0dfac20c06 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -1089,8 +1089,14 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" + # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. 
+ if isinstance(module, Wav2Vec2ConformerForPreTraining): + module.project_hid.reset_parameters() + module.project_q.reset_parameters() + module.project_hid._is_hf_initialized = True + module.project_q._is_hf_initialized = True # gumbel softmax requires special init - if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): + elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): module.weight_proj.weight.data.normal_(mean=0.0, std=1) module.weight_proj.bias.data.zero_() nn.init.uniform_(module.codevectors) @@ -1381,13 +1387,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config) - # Initialize weights and apply final processing - self.post_init() - - # make sure that project_hid & project_q are initialized like normal linear layers self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + # Initialize weights and apply final processing + self.post_init() + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature def set_gumbel_temperature(self, temperature: int): """ diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index f4e9305d6b..6347aafab7 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -962,7 +962,6 @@ class WavLMAdapterLayer(nn.Module): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->WavLM, wav2vec2->wavlm class WavLMPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index d6474c372f..b8f045442d 100644 --- 
a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1496,3 +1496,6 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un def test_retain_grad_hidden_states_attentions(self): # decoder cannot keep gradients return + + def test_save_load_fast_init_from_base(self): + pass diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index 7f1b43b2af..bb3d38b66f 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -410,17 +410,23 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): configs_no_init = _config_zero_init(config) for model_class in self.all_model_classes: model = model_class(config=configs_no_init) + # Skip the check for the backbone + for name, module in model.named_modules(): + if module.__class__.__name__ == "DetaBackboneWithPositionalEncodings": + backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()] + break + for name, param in model.named_parameters(): if param.requires_grad: - if param.requires_grad: - if ( - "level_embed" in name - or "sampling_offsets.bias" in name - or "value_proj" in name - or "output_proj" in name - or "reference_points" in name - ): - continue + if ( + "level_embed" in name + or "sampling_offsets.bias" in name + or "value_proj" in name + or "output_proj" in name + or "reference_points" in name + or name in backbone_params + ): + continue self.assertIn( ((param.data.mean() * 1e9).round() / 1e9).item(), [0.0, 1.0], diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 7393a27364..84c907539b 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -24,7 +24,7 @@ from transformers.models.auto import get_values from transformers.testing_utils import require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester 
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor if is_torch_available(): @@ -242,6 +242,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase): loss = model(**inputs).loss loss.backward() + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + # Skip the check for the backbone + backbone_params = [] + for name, module in model.named_modules(): + if module.__class__.__name__ == "DPTViTHybridEmbeddings": + backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()] + break + + for name, param in model.named_parameters(): + if param.requires_grad: + if name in backbone_params: + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + @slow def test_model_from_pretrained(self): for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/dpt/test_modeling_dpt_hybrid.py b/tests/models/dpt/test_modeling_dpt_hybrid.py index 494d595a5a..c98293e961 100644 --- a/tests/models/dpt/test_modeling_dpt_hybrid.py +++ b/tests/models/dpt/test_modeling_dpt_hybrid.py @@ -24,7 +24,7 @@ from transformers.models.auto import get_values from transformers.testing_utils import require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor if is_torch_available(): @@ -256,6 +256,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase): loss = model(**inputs).loss loss.backward() + 
def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + # Skip the check for the backbone + backbone_params = [] + for name, module in model.named_modules(): + if module.__class__.__name__ == "DPTViTHybridEmbeddings": + backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()] + break + + for name, param in model.named_parameters(): + if param.requires_grad: + if name in backbone_params: + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + @slow def test_model_from_pretrained(self): for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]: diff --git a/tests/models/layoutlmv2/test_modeling_layoutlmv2.py b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py index a4578d3534..8c51bc667a 100644 --- a/tests/models/layoutlmv2/test_modeling_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py @@ -15,9 +15,6 @@ """ Testing suite for the PyTorch LayoutLMv2 model. 
""" -import os -import random -import tempfile import unittest from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device @@ -31,7 +28,6 @@ if is_torch_available(): import torch from transformers import ( - MODEL_MAPPING, LayoutLMv2Config, LayoutLMv2ForQuestionAnswering, LayoutLMv2ForSequenceClassification, @@ -312,54 +308,6 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) - def test_save_load_fast_init_from_base(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - base_class = MODEL_MAPPING[config.__class__] - - if isinstance(base_class, tuple): - base_class = base_class[0] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - # make a copy of model class to not break future tests - # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class - class CopyClass(model_class): - pass - - model_class_copy = CopyClass - - # make sure that all keys are expected for test - model_class_copy._keys_to_ignore_on_load_missing = [] - - # make init deterministic, but make sure that - # non-initialized weights throw errors nevertheless - model_class_copy._init_weights = self._mock_init_weights - - model = base_class(config) - state_dict = model.state_dict() - - # this will often delete a single weight of a multi-weight module - # to test an edge case - random_key_to_del = random.choice(list(state_dict.keys())) - del state_dict[random_key_to_del] - - # check that certain keys didn't get saved with the model - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) - - model_fast_init = model_class_copy.from_pretrained(tmpdirname) - model_slow_init = 
model_class_copy.from_pretrained(tmpdirname, _fast_init=False) - - for key in model_fast_init.state_dict().keys(): - if key == "layoutlmv2.visual_segment_embedding": - # we skip the visual segment embedding as it has a custom initialization scheme - continue - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py index 86cc8f3896..2c98102a0c 100644 --- a/tests/models/prophetnet/test_modeling_prophetnet.py +++ b/tests/models/prophetnet/test_modeling_prophetnet.py @@ -436,10 +436,10 @@ class ProphetNetModelTester: decoder_attention_mask=decoder_attention_mask, labels=lm_labels, ) - self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5819, device=torch_device), atol=1e-3)) + self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5981, device=torch_device), atol=1e-3)) expected_logit_slice = torch.tensor( - [-0.1565, 0.0418, 0.1207, 0.0030, 0.0665, 0.0467, 0.0412], device=torch_device + [-0.0648, 0.0790, 0.0360, 0.0089, 0.0039, -0.0639, 0.0131], device=torch_device ) self.parent.assertTrue(torch.allclose(result.logits[0, :, 1], expected_logit_slice, atol=1e-3)) diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py index 4193607897..4af7b9864b 100644 --- a/tests/models/reformer/test_modeling_reformer.py +++ b/tests/models/reformer/test_modeling_reformer.py @@ -1145,10 +1145,11 @@ class ReformerIntegrationTests(unittest.TestCase): hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0] output_slice = hidden_states[1, -1, :5] expected_output_slice = torch.tensor( - [0.0256, -0.0121, 0.0636, 0.0024, -0.0393], + 
[0.1018, -0.2026, 0.2116, 0.0270, -0.1233], dtype=torch.float, device=torch_device, ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) def test_local_lm_model_grad(self): @@ -1163,25 +1164,25 @@ class ReformerIntegrationTests(unittest.TestCase): input_ids, _ = self._get_input_ids_and_mask() loss = model(input_ids=input_ids, labels=input_ids)[0] - self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3)) + self.assertTrue(torch.allclose(loss, torch.tensor(5.8019, dtype=torch.float, device=torch_device), atol=1e-3)) loss.backward() # check last grads to cover all proable errors grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5] expected_grad_slice_word = torch.tensor( - [-0.0005, 0.0001, 0.0002, 0.0003, 0.0006], + [-0.0005, -0.0001, -0.0002, -0.0006, -0.0006], dtype=torch.float, device=torch_device, ) grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] expected_grad_slice_pos_fac_1 = torch.tensor( - [0.0037, -1.3793, -1.0231, -1.5230, -2.5306], + [-0.5235, 0.5704, 0.0922, -0.3140, 0.9928], dtype=torch.float, device=torch_device, ) grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] expected_grad_slice_pos_fac_2 = torch.tensor( - [-1.3165, 0.5168, 0.7785, 1.0811, -0.9830], + [1.7960, 1.7668, 0.5593, 0.0907, 1.8342], dtype=torch.float, device=torch_device, ) @@ -1203,24 +1204,24 @@ class ReformerIntegrationTests(unittest.TestCase): input_ids, _ = self._get_input_ids_and_mask() loss = model(input_ids=input_ids, labels=input_ids)[0] - self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3)) + self.assertTrue(torch.allclose(loss, torch.tensor(5.7854, dtype=torch.float, device=torch_device), atol=1e-3)) loss.backward() # check last grads to cover all proable errors grad_slice_word = 
model.reformer.embeddings.word_embeddings.weight.grad[0, :5] expected_grad_slice_word = torch.tensor( - [2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04], + [0.0004, 0.0003, 0.0006, -0.0004, 0.0002], dtype=torch.float, device=torch_device, ) grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:] expected_grad_slice_pos_fac_1 = torch.tensor( - [-0.0984, 0.6283, 0.4282, 1.2960, 0.6897], + [-0.3792, 0.5593, -1.6993, 0.2033, 0.4131], dtype=torch.float, device=torch_device, ) grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5] expected_grad_slice_pos_fac_2 = torch.tensor( - [0.4626, -0.0231, -0.0172, 0.1081, 0.3805], + [-1.4212, -0.3201, -1.1944, 0.1258, 0.2856], dtype=torch.float, device=torch_device, ) diff --git a/tests/models/vit_hybrid/test_modeling_vit_hybrid.py b/tests/models/vit_hybrid/test_modeling_vit_hybrid.py index cf8d4b48e2..27913f28f4 100644 --- a/tests/models/vit_hybrid/test_modeling_vit_hybrid.py +++ b/tests/models/vit_hybrid/test_modeling_vit_hybrid.py @@ -23,7 +23,7 @@ from transformers.testing_utils import require_accelerate, require_torch, requir from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor if is_torch_available(): @@ -198,6 +198,28 @@ class ViTHybridModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = 
model_class(config=configs_no_init) + # Skip the check for the backbone + for name, module in model.named_modules(): + if module.__class__.__name__ == "ViTHybridPatchEmbeddings": + backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()] + break + + for name, param in model.named_parameters(): + if param.requires_grad: + if name in backbone_params: + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + @slow def test_model_from_pretrained(self): for model_name in VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f5d6357e93..26711f660d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -69,7 +69,6 @@ from transformers.testing_utils import ( USER, CaptureLogger, TestCasePlus, - is_flaky, is_pt_flax_cross_test, is_pt_tf_cross_test, is_staging_test, @@ -175,6 +174,9 @@ def _config_zero_init(config): for key in configs_no_init.__dict__.keys(): if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: setattr(configs_no_init, key, 1e-10) + if isinstance(getattr(configs_no_init, key, None), PretrainedConfig): + no_init_subconfig = _config_zero_init(getattr(configs_no_init, key)) + setattr(configs_no_init, key, no_init_subconfig) return configs_no_init @@ -182,6 +184,31 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random" TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification" +def _mock_init_weights(self, module): + for name, param in module.named_parameters(recurse=False): + # Use the first letter of the name to get a value and go from a <> -13 to z <> 12 + value = ord(name[0].lower()) - 110 + param.data.fill_(value) + + +def _mock_all_init_weights(self): + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + + 
import transformers.modeling_utils + + if transformers.modeling_utils._init_weights: + for module in self.modules(): + module._is_hf_initialized = False + # Initialize weights + self.apply(self._initialize_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() + + @require_torch class ModelTesterMixin: model_tester = None @@ -357,15 +384,10 @@ class ModelTesterMixin: model.gradient_checkpointing_disable() self.assertFalse(model.is_gradient_checkpointing) - def _mock_init_weights(self, module): - if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) - if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) - - @is_flaky() def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.__class__ not in MODEL_MAPPING: + return base_class = MODEL_MAPPING[config.__class__] if isinstance(base_class, tuple): @@ -387,7 +409,8 @@ class ModelTesterMixin: # make init deterministic, but make sure that # non-initialized weights throw errors nevertheless - model_class_copy._init_weights = self._mock_init_weights + model_class_copy._init_weights = _mock_init_weights + model_class_copy.init_weights = _mock_all_init_weights model = base_class(config) state_dict = model.state_dict() @@ -404,13 +427,16 @@ class ModelTesterMixin: model_fast_init = model_class_copy.from_pretrained(tmpdirname) model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False) + # Before we test anything for key in model_fast_init.state_dict().keys(): max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical") def test_save_load_fast_init_to_base(self): config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + if config.__class__ not in MODEL_MAPPING: + return base_class = MODEL_MAPPING[config.__class__] if isinstance(base_class, tuple): @@ -432,7 +458,8 @@ class ModelTesterMixin: # make init deterministic, but make sure that # non-initialized weights throw errors nevertheless - base_class_copy._init_weights = self._mock_init_weights + base_class_copy._init_weights = _mock_init_weights + base_class_copy.init_weights = _mock_all_init_weights model = model_class(config) state_dict = model.state_dict() @@ -454,7 +481,7 @@ class ModelTesterMixin: max_diff = torch.max( torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) ).item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical") def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()