From b16688e96a3b3e1e7a701cd5284a850c696c108e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 22 Jul 2025 16:04:20 +0200 Subject: [PATCH] General weight initialization scheme (#39579) * general + modulars from llama * all modular models * style and fix musicgen * fix * Update configuration_musicgen.py * Update modeling_utils.py --- src/transformers/modeling_utils.py | 39 ++++++++++-- .../models/aimv2/modeling_aimv2.py | 18 +----- .../models/aimv2/modular_aimv2.py | 18 +----- .../models/arcee/modeling_arcee.py | 13 ---- src/transformers/models/aria/modeling_aria.py | 33 ++-------- src/transformers/models/aria/modular_aria.py | 33 ++-------- .../models/aya_vision/modeling_aya_vision.py | 15 ----- .../models/aya_vision/modular_aya_vision.py | 15 ----- .../models/bamba/modeling_bamba.py | 14 +---- .../models/bamba/modular_bamba.py | 14 +---- .../models/biogpt/modeling_biogpt.py | 16 ----- .../models/biogpt/modular_biogpt.py | 16 ----- .../models/bitnet/modeling_bitnet.py | 13 ---- .../models/chameleon/modeling_chameleon.py | 17 ------ .../models/cohere/modeling_cohere.py | 13 ---- .../models/cohere/modular_cohere.py | 16 ----- .../models/cohere2/modeling_cohere2.py | 13 ---- src/transformers/models/csm/modeling_csm.py | 16 +---- src/transformers/models/csm/modular_csm.py | 16 +---- .../deepseek_v2/modeling_deepseek_v2.py | 14 +---- .../models/deepseek_v2/modular_deepseek_v2.py | 14 +---- .../deepseek_v3/modeling_deepseek_v3.py | 16 +---- .../models/deepseek_v3/modular_deepseek_v3.py | 16 +---- src/transformers/models/dia/modeling_dia.py | 13 ---- src/transformers/models/dia/modular_dia.py | 13 ---- .../models/diffllama/modeling_diffllama.py | 14 +---- .../models/diffllama/modular_diffllama.py | 14 +---- src/transformers/models/doge/modeling_doge.py | 13 +--- src/transformers/models/doge/modular_doge.py | 3 +- .../models/dots1/modeling_dots1.py | 16 +---- src/transformers/models/emu3/modeling_emu3.py | 13 ---- src/transformers/models/emu3/modular_emu3.py | 13 
---- .../models/ernie4_5/modeling_ernie4_5.py | 13 ---- .../ernie4_5_moe/modeling_ernie4_5_moe.py | 14 +---- .../ernie4_5_moe/modular_ernie4_5_moe.py | 14 +---- .../models/gemma/modeling_gemma.py | 13 ---- .../models/gemma2/modeling_gemma2.py | 13 ---- .../models/gemma3/modeling_gemma3.py | 15 +---- .../models/gemma3/modular_gemma3.py | 15 +---- .../models/gemma3n/modeling_gemma3n.py | 18 +----- .../models/gemma3n/modular_gemma3n.py | 18 +----- src/transformers/models/glm/modeling_glm.py | 13 ---- src/transformers/models/glm4/modeling_glm4.py | 13 ---- .../models/glm4_moe/modeling_glm4_moe.py | 16 +---- .../models/glm4v/modeling_glm4v.py | 16 ----- .../models/glm4v/modular_glm4v.py | 16 ----- .../models/got_ocr2/modeling_got_ocr2.py | 12 +--- .../models/got_ocr2/modular_got_ocr2.py | 12 +--- .../models/gpt_neox/modeling_gpt_neox.py | 14 ----- .../models/gpt_neox/modular_gpt_neox.py | 14 ----- .../models/granite/modeling_granite.py | 13 ---- .../models/granitemoe/modeling_granitemoe.py | 13 +--- .../modeling_granitemoehybrid.py | 20 +----- .../modular_granitemoehybrid.py | 9 +-- .../modeling_granitemoeshared.py | 13 +--- .../models/helium/modeling_helium.py | 13 ---- .../models/hgnet_v2/modeling_hgnet_v2.py | 10 --- .../models/hgnet_v2/modular_hgnet_v2.py | 10 --- .../models/informer/configuration_informer.py | 1 + .../models/informer/modeling_informer.py | 15 +---- .../models/informer/modular_informer.py | 15 +---- .../models/internvl/modeling_internvl.py | 27 +------- .../models/internvl/modular_internvl.py | 27 +------- .../models/janus/configuration_janus.py | 1 + .../models/janus/modeling_janus.py | 18 ------ .../models/janus/modular_janus.py | 19 +----- src/transformers/models/lfm2/modeling_lfm2.py | 13 ---- src/transformers/models/lfm2/modular_lfm2.py | 13 ---- .../models/lightglue/modeling_lightglue.py | 10 --- .../models/lightglue/modular_lightglue.py | 10 --- .../models/llama/modeling_llama.py | 13 ---- .../models/llava/modeling_llava.py | 14 ----- 
.../models/minimax/modeling_minimax.py | 13 ---- .../models/mistral/modeling_mistral.py | 13 ---- .../models/mistral3/modeling_mistral3.py | 16 ----- .../models/mistral3/modular_mistral3.py | 16 +---- .../models/mixtral/modeling_mixtral.py | 13 ---- .../models/moonshine/modeling_moonshine.py | 15 ----- .../models/moonshine/modular_moonshine.py | 15 ----- .../models/musicgen/configuration_musicgen.py | 1 + .../models/musicgen/modeling_musicgen.py | 5 +- src/transformers/models/olmo/modeling_olmo.py | 11 ---- src/transformers/models/olmo/modular_olmo.py | 20 ++---- .../models/olmo2/modeling_olmo2.py | 13 ---- .../perception_lm/modeling_perception_lm.py | 14 ----- src/transformers/models/phi/modeling_phi.py | 14 ----- src/transformers/models/phi/modular_phi.py | 19 +----- src/transformers/models/phi3/modeling_phi3.py | 13 ---- .../modeling_phi4_multimodal.py | 29 ++------- .../modular_phi4_multimodal.py | 29 ++------- .../models/plbart/configuration_plbart.py | 6 +- .../models/plbart/modeling_plbart.py | 11 ---- .../models/plbart/modular_plbart.py | 11 ---- .../modeling_prompt_depth_anything.py | 7 --- .../modular_prompt_depth_anything.py | 7 --- .../models/qwen2/modeling_qwen2.py | 13 ---- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 61 ++++++------------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 22 ------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 13 ---- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 13 +--- .../models/qwen2_vl/modeling_qwen2_vl.py | 16 ----- .../models/qwen3/modeling_qwen3.py | 13 ---- .../models/qwen3_moe/modeling_qwen3_moe.py | 13 ---- src/transformers/models/sam/modeling_sam.py | 15 +---- .../models/sam_hq/modeling_sam_hq.py | 17 +----- .../models/sam_hq/modular_sam_hq.py | 3 +- .../models/smollm3/modeling_smollm3.py | 13 ---- .../models/smolvlm/modeling_smolvlm.py | 22 +++++++ .../models/smolvlm/modular_smolvlm.py | 15 +---- .../models/starcoder2/modeling_starcoder2.py | 14 ----- .../models/starcoder2/modular_starcoder2.py | 17 ------ 
.../models/t5gemma/modeling_t5gemma.py | 13 +--- .../models/t5gemma/modular_t5gemma.py | 13 +--- .../models/timesfm/modeling_timesfm.py | 18 +----- .../models/timesfm/modular_timesfm.py | 18 +----- .../models/vipllava/modeling_vipllava.py | 14 ----- .../models/zamba2/modeling_zamba2.py | 14 +---- .../models/zamba2/modular_zamba2.py | 14 +---- 118 files changed, 205 insertions(+), 1566 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f4fd894b32..71eb25f74a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2967,12 +2967,41 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH def _init_weights(self, module): """ - Initialize the weights. This method should be overridden by derived class and is - the only initialization method that will be called when loading a checkpoint - using `from_pretrained`. Any attempt to initialize outside of this function - will be useless as the torch.nn.init function are all replaced with skip. + Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex + initialization scheme, it should be overriden by the derived `PreTrainedModel` class. In case a model adds an explicit + `nn.Parameter`, this method should also be overriden in order to initialize it correctly. 
""" - pass + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + else: + # 0.02 is the standard default value accross the library + std = getattr(self.config.get_text_config(), "initializer_range", 0.02) + + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.MultiheadAttention): + # This uses torch's original init + module._reset_parameters() + # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names + # between modelings (because they are prefixed with the model name) + elif ( + isinstance( + module, (nn.LayerNorm, nn.RMSNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) + ) + or "LayerNorm" in module.__class__.__name__ + or "RMSNorm" in module.__class__.__name__ + ): + # Norms can exist without weights (in which case they are None from torch primitives) + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() def _initialize_weights(self, module): """ diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index 7b124a64c9..ff1eb62781 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -448,24 +448,12 @@ class Aimv2PreTrainedModel(PreTrainedModel): _supports_flex_attn = True def _init_weights(self, module): - std = ( - self.config.vision_config.initializer_range - if hasattr(self.config, "vision_config") - else self.config.initializer_range - ) - if 
isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, Aimv2RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - elif hasattr(module, "logit_scale"): + super()._init_weights(module) + if hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): module.logit_scale.data.fill_(math.log(1 / 0.07)) elif isinstance(module, Aimv2AttentionPoolingHead): - module.cls_token.data.normal_(mean=0.0, std=std) + module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git a/src/transformers/models/aimv2/modular_aimv2.py b/src/transformers/models/aimv2/modular_aimv2.py index 7c83bf4e2d..ad2c3d78aa 100644 --- a/src/transformers/models/aimv2/modular_aimv2.py +++ b/src/transformers/models/aimv2/modular_aimv2.py @@ -445,24 +445,12 @@ class Aimv2PreTrainedModel(PreTrainedModel): _supports_flex_attn = True def _init_weights(self, module): - std = ( - self.config.vision_config.initializer_range - if hasattr(self.config, "vision_config") - else self.config.initializer_range - ) - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, Aimv2RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - elif hasattr(module, "logit_scale"): + super()._init_weights(module) + if hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): module.logit_scale.data.fill_(math.log(1 / 0.07)) elif isinstance(module, Aimv2AttentionPoolingHead): - module.cls_token.data.normal_(mean=0.0, std=std) + module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git 
a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 43a02ebd8c..9c7f43a84d 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -324,19 +324,6 @@ class ArceePreTrainedModel(PreTrainedModel): "attentions": ArceeAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, ArceeRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class ArceeModel(ArceePreTrainedModel): diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 6c5c972b1f..29bbfb16fe 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -639,19 +639,9 @@ class AriaTextPreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, AriaTextRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, AriaGroupedExpertsGemm): - module.weight.data.normal_(mean=0.0, std=std) + super()._init_weights(module) + if isinstance(module, AriaGroupedExpertsGemm): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -672,20 +662,9 @@ class AriaPreTrainedModel(PreTrainedModel): } def 
_init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.MultiheadAttention): - # This uses torch's original init - module._reset_parameters() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, AriaProjector): - nn.init.trunc_normal_(module.query, std=std) + super()._init_weights(module) + if isinstance(module, AriaProjector): + nn.init.trunc_normal_(module.query, std=self.config.initializer_range) class AriaTextRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index f6303b4d38..1fb64f50b7 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1294,19 +1294,9 @@ class AriaTextPreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, AriaTextRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, AriaGroupedExpertsGemm): - module.weight.data.normal_(mean=0.0, std=std) + super()._init_weights(module) + if isinstance(module, AriaGroupedExpertsGemm): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) class AriaPreTrainedModel(LlamaPreTrainedModel): @@ -1316,20 +1306,9 @@ class AriaPreTrainedModel(LlamaPreTrainedModel): _supports_attention_backend = True def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, nn.Linear): 
- module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.MultiheadAttention): - # This uses torch's original init - module._reset_parameters() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, AriaProjector): - nn.init.trunc_normal_(module.query, std=std) + LlamaPreTrainedModel._init_weights(module) + if isinstance(module, AriaProjector): + nn.init.trunc_normal_(module.query, std=self.config.initializer_range) class AriaTextModel(LlamaModel): diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index cb48470687..70c85ab27f 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -100,21 +100,6 @@ class AyaVisionPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - @dataclass @auto_docstring( diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 58c118d73f..9dcf02547a 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -92,21 +92,6 @@ class AyaVisionMultiModalProjector(nn.Module): class AyaVisionPreTrainedModel(LlavaPreTrainedModel): _supports_static_cache = False - def _init_weights(self, module): - std = ( - self.config.initializer_range - 
if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 3548e706a5..7389ac5a86 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1088,18 +1088,8 @@ class BambaPreTrainedModel(PreTrainedModel): _is_stateful = True def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, BambaMixer): + super()._init_weights(module) + if isinstance(module, BambaMixer): module.dt_bias.data.fill_(1.0) module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) module.D.data.fill_(1.0) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 9bfbfd159f..f99faa9ed7 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -816,18 +816,8 @@ class BambaPreTrainedModel(PreTrainedModel): _is_stateful = True def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if 
module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, BambaMixer): + super()._init_weights(module) + if isinstance(module, BambaMixer): module.dt_bias.data.fill_(1.0) module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) module.D.data.fill_(1.0) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a873cd6b69..63ad37033e 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -349,22 +349,6 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_static_cache = True - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 3f63caddea..43717ae447 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -174,22 +174,6 @@ 
class BioGptPreTrainedModel(PreTrainedModel): _supports_static_cache = True - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index c373d659c7..59bd7b2ef9 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -319,19 +319,6 @@ class BitNetPreTrainedModel(PreTrainedModel): "attentions": BitNetAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, BitNetRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class BitNetModel(BitNetPreTrainedModel): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 5d8e6fc210..1e43bbc6ec 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ 
b/src/transformers/models/chameleon/modeling_chameleon.py @@ -820,23 +820,6 @@ class ChameleonPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ChameleonRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 1fb91bccaa..15934195ce 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -352,19 +352,6 @@ class CoherePreTrainedModel(PreTrainedModel): "attentions": CohereAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, CohereLayerNorm): - module.weight.data.fill_(1.0) - @auto_docstring class CohereModel(CoherePreTrainedModel): diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 46a58f9c7f..0e27f4dfc4 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -41,7 +41,6 @@ from 
..llama.modeling_llama import ( LlamaForCausalLM, LlamaMLP, LlamaModel, - LlamaPreTrainedModel, LlamaRotaryEmbedding, eager_attention_forward, ) @@ -255,21 +254,6 @@ class CohereDecoderLayer(GradientCheckpointingLayer): return hidden_states -class CoherePreTrainedModel(LlamaPreTrainedModel): - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, CohereLayerNorm): - module.weight.data.fill_(1.0) - - class CohereModel(LlamaModel): def __init__(self, config: CohereConfig): super().__init__(config) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index f3dc518f92..85faa46d17 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -329,19 +329,6 @@ class Cohere2PreTrainedModel(PreTrainedModel): "attentions": Cohere2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Cohere2LayerNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Cohere2Model(Cohere2PreTrainedModel): diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 31e9ff3689..3cc477b6dc 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py 
@@ -379,21 +379,11 @@ class CsmPreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, CsmCodebooksHead): + super()._init_weights(module) + if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): - module.weight.data[i].normal_(mean=0.0, std=std) - elif isinstance(module, CsmRMSNorm): - module.weight.data.fill_(1.0) + module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 4701612752..60b2612759 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -142,21 +142,11 @@ class CsmPreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, CsmCodebooksHead): + super()._init_weights(module) + if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): - module.weight.data[i].normal_(mean=0.0, std=std) - elif isinstance(module, CsmRMSNorm): - module.weight.data.fill_(1.0) + module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git 
a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 595953fd6c..d7dd466e3d 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -467,18 +467,8 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DeepseekV2RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, DeepseekV2MoEGate): + super()._init_weights(module) + if isinstance(module, DeepseekV2MoEGate): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index 8244dd8572..dddb74d1bf 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -503,18 +503,8 @@ class DeepseekV2DecoderLayer(LlamaDecoderLayer): class DeepseekV2PreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DeepseekV2RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, DeepseekV2MoEGate): + 
LlamaPreTrainedModel._init_weights(module) + if isinstance(module, DeepseekV2MoEGate): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 05171a8359..38d39d03ab 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -506,19 +506,9 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DeepseekV3RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, DeepseekV3TopkRouter): - module.weight.data.normal_(mean=0.0, std=std) + super()._init_weights(module) + if isinstance(module, DeepseekV3TopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 73f3acc204..ebcd3aed39 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -339,19 +339,9 @@ class DeepseekV3DecoderLayer(LlamaDecoderLayer, nn.Module): class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, 
std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DeepseekV3RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, DeepseekV3TopkRouter): - module.weight.data.normal_(mean=0.0, std=std) + LlamaPreTrainedModel._init_weights(module) + if isinstance(module, DeepseekV3TopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) class DeepseekV3Model(LlamaModel): diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 2bf05cf683..d115f3766f 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -71,19 +71,6 @@ class DiaPreTrainedModel(PreTrainedModel): main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DiaRMSNorm): - module.weight.data.fill_(1.0) - class DiaMultiChannelEmbedding(nn.Module): """In order to efficiently compute the audio embedding from the 9 different channels, diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 8c84d936c5..6934f6e4f7 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -66,19 +66,6 @@ class DiaPreTrainedModel(PreTrainedModel): main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if 
module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DiaRMSNorm): - module.weight.data.fill_(1.0) - class DiaMultiChannelEmbedding(nn.Module): """In order to efficiently compute the audio embedding from the 9 different channels, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 92badf62f2..df4c9d0159 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -542,18 +542,8 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821 - module.weight.data.fill_(1.0) - elif isinstance(module, DiffLlamaAttention): + super()._init_weights(module) + if isinstance(module, DiffLlamaAttention): module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 8091e87ab8..b5034f1749 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -404,18 +404,8 @@ class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): _supports_attention_backend = False def _init_weights(self, module): 
- std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821 - module.weight.data.fill_(1.0) - elif isinstance(module, DiffLlamaAttention): + LlamaPreTrainedModel._init_weights(module) + if isinstance(module, DiffLlamaAttention): module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 21b2794c03..fccfe71de2 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -504,18 +504,7 @@ class DogePreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, DogeRMSNorm): - module.weight.data.fill_(1.0) - + super()._init_weights(module) if isinstance(module, DogeAttention): if hasattr(module, "A"): module.A.data.zero_() diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index a3d1b4f9bf..49094ec499 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -573,8 +573,7 @@ class DogePreTrainedModel(LlamaPreTrainedModel): def 
_init_weights(self, module): """Initialize the weights""" - super()._init_weights(module) - + LlamaPreTrainedModel._init_weights(module) if isinstance(module, DogeAttention): if hasattr(module, "A"): module.A.data.zero_() diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 06df98b835..4a90048d7e 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -426,19 +426,9 @@ class Dots1PreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Dots1RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Dots1TopkRouter): - module.weight.data.normal_(mean=0.0, std=std) + super()._init_weights(module) + if isinstance(module, Dots1TopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index fd7cbf39e1..570c32c2fb 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1103,19 +1103,6 @@ class Emu3PreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is 
not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Emu3RMSNorm): # noqa: F821 - module.weight.data.fill_(1.0) - class Emu3RotaryEmbedding(nn.Module): def __init__(self, config: Emu3Config, device=None): diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 580bf670e3..9aa66fe65e 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -845,19 +845,6 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Emu3RMSNorm): # noqa: F821 - module.weight.data.fill_(1.0) - class Emu3TextModel(LlamaModel, Emu3PreTrainedModel): _can_record_outputs = { diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index 0d98583e1b..0a9a170184 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -318,19 +318,6 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): "attentions": Ernie4_5Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif 
isinstance(module, Ernie4_5RMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Ernie4_5Model(Ernie4_5PreTrainedModel): diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 630162b7dc..9081f66265 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -483,18 +483,8 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel): _keep_in_fp32_modules_strict = ["gate", "moe_statics"] def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Ernie4_5_MoERMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Ernie4_5_MoEStatics): + super()._init_weights(module) + if isinstance(module, Ernie4_5_MoEStatics): module.e_score_correction_bias.data.zero_() diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index daf122929b..0763e415c5 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -218,18 +218,8 @@ class Ernie4_5_MoEPreTrainedModel(MixtralPreTrainedModel): _keep_in_fp32_modules_strict = ["gate", "moe_statics"] def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - 
module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Ernie4_5_MoERMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Ernie4_5_MoEStatics): + MixtralPreTrainedModel._init_weights(module) + if isinstance(module, Ernie4_5_MoEStatics): module.e_score_correction_bias.data.zero_() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 287c9b3013..770a0bbb8c 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -321,19 +321,6 @@ class GemmaPreTrainedModel(PreTrainedModel): "attentions": GemmaAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GemmaRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class GemmaModel(GemmaPreTrainedModel): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6db1b1f7bb..b31e94b3d9 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -351,19 +351,6 @@ class Gemma2PreTrainedModel(PreTrainedModel): "attentions": Gemma2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Gemma2RMSNorm): - 
module.weight.data.fill_(1.0) - @auto_docstring class Gemma2Model(Gemma2PreTrainedModel): diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 394e380021..42ec9410b9 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -442,19 +442,8 @@ class Gemma3PreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Gemma3RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Gemma3MultiModalProjector): + super()._init_weights(module) + if isinstance(module, Gemma3MultiModalProjector): module.mm_input_projection_weight.data.zero_() diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 57ecedca91..5bf158c00f 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -529,19 +529,8 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel): ] def _init_weights(self, module): - std = self.config.initializer_range - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Gemma3RMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Gemma3MultiModalProjector): + Gemma2PreTrainedModel._init_weights(module) + if 
isinstance(module, Gemma3MultiModalProjector): module.mm_input_projection_weight.data.zero_() diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 1411cccef9..f581102639 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1498,22 +1498,8 @@ class Gemma3nPreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - # important: this ported version of Gemma2 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Gemma3nRMSNorm): - if module.with_scale: - module.weight.data.fill_(1.0) - elif isinstance(module, Gemma3nAudioCumulativeGroupNorm): + super()._init_weights(module) + if isinstance(module, Gemma3nAudioCumulativeGroupNorm): module.weight.data.fill_(1.0) elif isinstance(module, Gemma3nAudioAttention): module.per_dim_scale.data.zero_() diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 8e3bcfd1f1..a9c2dabc76 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1917,22 +1917,8 @@ class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): _no_split_modules = ["Gemma3nTextDecoderLayer"] def _init_weights(self, module): - # important: this ported version of Gemma2 isn't meant for training from scratch - only - # inference and fine-tuning - so the 
proper init weights code has been removed - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Gemma3nRMSNorm): - if module.with_scale: - module.weight.data.fill_(1.0) - elif isinstance(module, Gemma3nAudioCumulativeGroupNorm): + Gemma2PreTrainedModel._init_weights(module) + if isinstance(module, Gemma3nAudioCumulativeGroupNorm): module.weight.data.fill_(1.0) elif isinstance(module, Gemma3nAudioAttention): module.per_dim_scale.data.zero_() diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 72733a80f7..dc6cceb052 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -338,19 +338,6 @@ class GlmPreTrainedModel(PreTrainedModel): "attentions": GlmAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GlmRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class GlmModel(GlmPreTrainedModel): diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index e4dd64102d..6c7aeaf02c 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -342,19 +342,6 @@ class 
Glm4PreTrainedModel(PreTrainedModel): "attentions": Glm4Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Glm4RMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Glm4Model(Glm4PreTrainedModel): diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 31ad8ede95..cba4ee896b 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -411,19 +411,9 @@ class Glm4MoePreTrainedModel(PreTrainedModel): } def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Glm4MoeRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Glm4MoeTopkRouter): - module.weight.data.normal_(mean=0.0, std=std) + super()._init_weights(module) + if isinstance(module, Glm4MoeTopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) class Glm4MoeRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 41e37b3e1a..57d99a1f98 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -410,22 +410,6 @@ class 
Glm4vPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Glm4vRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - class Glm4vVisionModel(Glm4vPreTrainedModel): config: Glm4vVisionConfig diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 789744e326..71cd472cdb 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -525,22 +525,6 @@ class Glm4vVisionBlock(Qwen2_5_VLVisionBlock): class Glm4vPreTrainedModel(Qwen2_5_VLPreTrainedModel): _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"] - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Glm4vRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - class Glm4vVisionModel(Glm4vPreTrainedModel): config: Glm4vVisionConfig diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py 
b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 99959cd74b..dead80a503 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -288,16 +288,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_attention_backend = True def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821 - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, GotOcr2VisionAttention): + super()._init_weights(module) + if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: module.rel_pos_h.data.zero_() module.rel_pos_w.data.zero_() diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 06b9fca298..7b381e08cc 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -291,16 +291,8 @@ class GotOcr2PreTrainedModel(LlavaPreTrainedModel): _supports_flex_attn = False def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821 - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, GotOcr2VisionAttention): + LlavaPreTrainedModel._init_weights(module) + if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: module.rel_pos_h.data.zero_() module.rel_pos_w.data.zero_() diff --git 
a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 15cc664d74..ec9efd340c 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -372,20 +372,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): } _keys_to_ignore_on_load_unexpected = [r"attention.bias", r"attention.masked_bias"] - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - @auto_docstring class GPTNeoXModel(GPTNeoXPreTrainedModel): diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 704b69aa5d..860c1ea99a 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -247,20 +247,6 @@ class GPTNeoXPreTrainedModel(LlamaPreTrainedModel): _no_split_modules = ["GPTNeoXLayer"] _keys_to_ignore_on_load_unexpected = [r"attention.bias", r"attention.masked_bias"] - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - 
module.weight.data.fill_(1.0) - GPT_NEOX_START_DOCSTRING = None # Will be picked up by modular GPT_NEOX_INPUTS_DOCSTRING = None # Will be picked up by modular diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 8bebef03c2..ba2163c435 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -316,19 +316,6 @@ class GranitePreTrainedModel(PreTrainedModel): "attentions": GraniteAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GraniteRMSNorm): - module.weight.data.fill_(1.0) - class GraniteRotaryEmbedding(nn.Module): def __init__(self, config: GraniteConfig, device=None): diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index bf72cc85da..8a51c28053 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -595,17 +595,8 @@ class GraniteMoePreTrainedModel(PreTrainedModel): _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif 
isinstance(module, GraniteMoeRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, GraniteMoeParallelExperts): + super()._init_weights(module) + if isinstance(module, GraniteMoeParallelExperts): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 8b3f3d1dcc..f569be5ef1 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1212,24 +1212,10 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): _is_stateful = True def _init_weights(self, module): - if isinstance(module, nn.Linear): + super()._init_weights(module) + if isinstance(module, GraniteMoeHybridParallelExperts): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GraniteMoeHybridRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, GraniteMoeHybridParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - # Initialize Mamba modules - if isinstance(module, (nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, GraniteMoeHybridMambaLayer): + if isinstance(module, GraniteMoeHybridMambaLayer): module.dt_bias.data.fill_(1.0) module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) module.D.data.fill_(1.0) diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py 
b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 80c5cfd430..4274af1b19 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -167,13 +167,8 @@ class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel): _is_stateful = True def _init_weights(self, module): - super()._init_weights() - # Initialize Mamba modules - if isinstance(module, (nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, GraniteMoeHybridMambaLayer): + super()._init_weights(module) + if isinstance(module, GraniteMoeHybridMambaLayer): module.dt_bias.data.fill_(1.0) module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) module.D.data.fill_(1.0) diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 83f78ae327..9afb268a49 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -513,17 +513,8 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GraniteMoeSharedRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, GraniteMoeSharedParallelExperts): + 
super()._init_weights(module) + if isinstance(module, GraniteMoeSharedParallelExperts): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index f15fbd48dd..83fa413e5a 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -323,19 +323,6 @@ class HeliumPreTrainedModel(PreTrainedModel): "attentions": HeliumAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, HeliumRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class HeliumModel(HeliumPreTrainedModel): diff --git a/src/transformers/models/hgnet_v2/modeling_hgnet_v2.py b/src/transformers/models/hgnet_v2/modeling_hgnet_v2.py index e9620ade40..16dfeb77c2 100644 --- a/src/transformers/models/hgnet_v2/modeling_hgnet_v2.py +++ b/src/transformers/models/hgnet_v2/modeling_hgnet_v2.py @@ -45,16 +45,6 @@ class HGNetV2PreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" _no_split_modules = ["HGNetV2BasicLayer"] - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.BatchNorm2d): - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - class HGNetV2LearnableAffineBlock(nn.Module): def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0): diff --git a/src/transformers/models/hgnet_v2/modular_hgnet_v2.py 
b/src/transformers/models/hgnet_v2/modular_hgnet_v2.py index f5b8735f46..f0c90ce0a6 100644 --- a/src/transformers/models/hgnet_v2/modular_hgnet_v2.py +++ b/src/transformers/models/hgnet_v2/modular_hgnet_v2.py @@ -170,16 +170,6 @@ class HGNetV2PreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" _no_split_modules = ["HGNetV2BasicLayer"] - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.BatchNorm2d): - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - class HGNetV2LearnableAffineBlock(nn.Module): def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0): diff --git a/src/transformers/models/informer/configuration_informer.py b/src/transformers/models/informer/configuration_informer.py index 559c482ea3..f62417358c 100644 --- a/src/transformers/models/informer/configuration_informer.py +++ b/src/transformers/models/informer/configuration_informer.py @@ -132,6 +132,7 @@ class InformerConfig(PretrainedConfig): "hidden_size": "d_model", "num_attention_heads": "encoder_attention_heads", "num_hidden_layers": "encoder_layers", + "initializer_range": "init_std", } def __init__( diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index ed4cda28e4..34db267e7f 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -257,20 +257,9 @@ class InformerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True def _init_weights(self, module: nn.Module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, 
InformerSinusoidalPositionalEmbedding): + super()._init_weights(module) + if isinstance(module, InformerSinusoidalPositionalEmbedding): module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 0b3ecb5936..84ddd83caa 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -98,20 +98,9 @@ class InformerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True def _init_weights(self, module: nn.Module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, InformerSinusoidalPositionalEmbedding): + super()._init_weights(module) + if isinstance(module, InformerSinusoidalPositionalEmbedding): module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 46ef56dc46..55b9b6dcfe 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ 
b/src/transformers/models/internvl/modeling_internvl.py @@ -185,20 +185,8 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, InternVLVisionEmbeddings): + super()._init_weights(module) + if isinstance(module, InternVLVisionEmbeddings): module.cls_token.data.zero_() if module.mask_token is not None: module.mask_token.data.zero_() @@ -528,17 +516,6 @@ class InternVLPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - class InternVLMultiModalProjector(nn.Module): def __init__(self, config: InternVLConfig): diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 45d8c3be3e..67d864eec7 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -147,20 +147,8 @@ class 
InternVLVisionPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, InternVLVisionEmbeddings): + super()._init_weights(module) + if isinstance(module, InternVLVisionEmbeddings): module.cls_token.data.zero_() if module.mask_token is not None: module.mask_token.data.zero_() @@ -466,16 +454,7 @@ class InternVLVisionModel(InternVLVisionPreTrainedModel): class InternVLPreTrainedModel(LlavaPreTrainedModel): - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + pass INTERNVL_INPUTS_DOCSTRING = None diff --git a/src/transformers/models/janus/configuration_janus.py b/src/transformers/models/janus/configuration_janus.py index 5063640bfc..a0e758fd9e 100644 --- a/src/transformers/models/janus/configuration_janus.py +++ b/src/transformers/models/janus/configuration_janus.py @@ -311,6 +311,7 @@ class JanusConfig(PretrainedConfig): f" Type found: {type(vq_config)}" ) + self.initializer_range = self.vision_config.initializer_range # This dimension is required when 
decoding discrete image tokens to continuous input. self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size # The default is only the index for the 1B model, 7B uses a different one diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 2a2257dfec..0553326af1 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -66,24 +66,6 @@ class JanusPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_param_buffer_assignment = False - def _init_weights(self, module): - std = ( - self.config.vision_config.initializer_range - if hasattr(self.config, "vision_config") - else self.config.initializer_range - ) - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - @dataclass @auto_docstring( diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 81e6390183..69b9293374 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -373,6 +373,7 @@ class JanusConfig(PretrainedConfig): f" Type found: {type(vq_config)}" ) + self.initializer_range = self.vision_config.initializer_range # This dimension is required when decoding discrete image tokens to continuous input. 
self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size # The default is only the index for the 1B model, 7B uses a different one @@ -393,24 +394,6 @@ class JanusPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_param_buffer_assignment = False - def _init_weights(self, module): - std = ( - self.config.vision_config.initializer_range - if hasattr(self.config, "vision_config") - else self.config.initializer_range - ) - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - @dataclass @auto_docstring( diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 0d383769d1..21c3d417d0 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -549,19 +549,6 @@ class Lfm2PreTrainedModel(PreTrainedModel): "attentions": Lfm2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Lfm2RMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Lfm2Model(Lfm2PreTrainedModel): diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 338e6ec524..1cdb15d370 100644 --- 
a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -403,19 +403,6 @@ class Lfm2DecoderLayer(GradientCheckpointingLayer): class Lfm2PreTrainedModel(LlamaPreTrainedModel): _supports_static_cache = False - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Lfm2RMSNorm): - module.weight.data.fill_(1.0) - class Lfm2Model(LlamaModel): def __init__(self, config: Lfm2Config): diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index e635f13e33..9c768e80af 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -426,16 +426,6 @@ class LightGluePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - def _init_weights(self, module: nn.Module) -> None: - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]: """obtain matches from a score matrix [Bx M+1 x N+1]""" diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py index 78caf28f15..d90708e6bd 100644 --- a/src/transformers/models/lightglue/modular_lightglue.py +++ 
b/src/transformers/models/lightglue/modular_lightglue.py @@ -511,16 +511,6 @@ class LightGluePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - def _init_weights(self, module: nn.Module) -> None: - """Initialize the weights""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]: """obtain matches from a score matrix [Bx M+1 x N+1]""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4bab75a87c..9d194da3ca 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -322,19 +322,6 @@ class LlamaPreTrainedModel(PreTrainedModel): "attentions": LlamaAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, LlamaRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class LlamaModel(LlamaPreTrainedModel): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 032751a4e1..56f056023b 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -125,20 +125,6 @@ class LlavaPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - # important: 
this ported version of Llava isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 6923fdc91a..da427b8b7e 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -597,19 +597,6 @@ class MiniMaxPreTrainedModel(PreTrainedModel): "attentions": [MiniMaxAttention, MiniMaxLightningAttention], } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MiniMaxRMSNorm): - module.weight.data.fill_(1.0) - class MiniMaxRotaryEmbedding(nn.Module): def __init__(self, config: MiniMaxConfig, device=None): diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1189f1b34d..4e0d30d89a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -267,19 +267,6 @@ class MistralPreTrainedModel(PreTrainedModel): "attentions": MistralAttention, } - def 
_init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MistralRMSNorm): - module.weight.data.fill_(1.0) - class MistralRotaryEmbedding(nn.Module): def __init__(self, config: MistralConfig, device=None): diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index d1a9c83f9d..cf46657627 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -190,22 +190,6 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - # important: this ported version of Mistral3 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, Mistral3RMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 5b5f27579c..da507655fa 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ 
b/src/transformers/models/mistral3/modular_mistral3.py @@ -115,21 +115,7 @@ class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast): class Mistral3PreTrainedModel(LlavaPreTrainedModel): - def _init_weights(self, module): - # important: this ported version of Mistral3 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, Mistral3RMSNorm): - module.weight.data.fill_(1.0) + pass class Mistral3Model(LlavaModel): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index ec18500269..2b700467d0 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -396,19 +396,6 @@ class MixtralPreTrainedModel(PreTrainedModel): "attentions": MixtralAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MixtralRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class MixtralModel(MixtralPreTrainedModel): diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py 
index 9b229e4074..1143ee87ec 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -465,21 +465,6 @@ class MoonshinePreTrainedModel(PreTrainedModel): _supports_static_cache = True # TODO arthur, how do we separate when it cross / self coming from different layer? - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 9706d99d7c..4619379352 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -500,21 +500,6 @@ class MoonshinePreTrainedModel(PreTrainedModel): _supports_static_cache = True # TODO arthur, how do we separate when it cross / self coming from different layer? 
- def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py index 1618af7e43..878cc122f1 100644 --- a/src/transformers/models/musicgen/configuration_musicgen.py +++ b/src/transformers/models/musicgen/configuration_musicgen.py @@ -214,6 +214,7 @@ class MusicgenConfig(PretrainedConfig): self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) self.decoder = MusicgenDecoderConfig(**decoder_config) self.is_encoder_decoder = True + self.initializer_factor = self.decoder.initializer_factor @classmethod def from_sub_models_config( diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 64edaee56f..cd76c87162 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1333,14 +1333,11 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin): The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, """ ) -class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): +class MusicgenForConditionalGeneration(MusicgenPreTrainedModel, GenerationMixin): config: MusicgenConfig base_model_prefix = 
"encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = True def __init__( self, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 817341cc25..2a4756e307 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -301,17 +301,6 @@ class OlmoPreTrainedModel(PreTrainedModel): "attentions": OlmoAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - @auto_docstring class OlmoModel(OlmoPreTrainedModel): diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 44678f0939..6bc439888e 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -14,7 +14,6 @@ from ..llama.modeling_llama import ( LlamaForCausalLM, LlamaMLP, LlamaModel, - LlamaPreTrainedModel, LlamaRotaryEmbedding, eager_attention_forward, rotate_half, @@ -153,19 +152,6 @@ class OlmoRotaryEmbedding(LlamaRotaryEmbedding): return cos, sin -class OlmoPreTrainedModel(LlamaPreTrainedModel): - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - class OlmoModel(LlamaModel): def __init__(self, config: OlmoConfig): 
super().__init__(config) @@ -179,4 +165,8 @@ class OlmoForCausalLM(LlamaForCausalLM): pass -__all__ = ["OlmoForCausalLM", "OlmoModel", "OlmoPreTrainedModel"] +__all__ = [ + "OlmoForCausalLM", + "OlmoModel", + "OlmoPreTrainedModel", # noqa: F822 +] diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index b589364df5..097b7a8d16 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -306,19 +306,6 @@ class Olmo2PreTrainedModel(PreTrainedModel): "attentions": Olmo2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Olmo2RMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Olmo2Model(Olmo2PreTrainedModel): diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index 657aa80569..37e1bed147 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -99,20 +99,6 @@ class PerceptionLMPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - # important: this ported version of PerceptionLM isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/PerceptionLM/tree/main/perception_lm should serve for that purpose - std = getattr(self.config, "initializer_range", 
self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - @dataclass @auto_docstring( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index ea77b2d471..dcef7cdcf4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -306,20 +306,6 @@ class PhiPreTrainedModel(PreTrainedModel): "attentions": PhiAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - @auto_docstring class PhiModel(PhiPreTrainedModel): diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 1448c24827..f197abf129 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -19,7 +19,6 @@ from ..llama.modeling_llama import ( LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, - LlamaPreTrainedModel, LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, # copied from Llama @@ -169,22 +168,6 @@ class PhiRotaryEmbedding(LlamaRotaryEmbedding): pass -class PhiPreTrainedModel(LlamaPreTrainedModel): - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - 
module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - - class PhiModel(LlamaModel): def __init__(self, config: PhiConfig): super().__init__(config) @@ -307,7 +290,7 @@ class PhiForTokenClassification(LlamaForTokenClassification): __all__ = [ - "PhiPreTrainedModel", + "PhiPreTrainedModel", # noqa: F822 "PhiModel", "PhiForCausalLM", "PhiForSequenceClassification", diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index c896def491..32e1f2e0aa 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -299,19 +299,6 @@ class Phi3PreTrainedModel(PreTrainedModel): } _version = "0.0.5" - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Phi3RMSNorm): - module.weight.data.fill_(1.0) - class Phi3RotaryEmbedding(nn.Module): def __init__(self, config: Phi3Config, device=None): diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 855e8b7fc1..a984a7957c 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1000,19 +1000,8 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): _supports_flex_attn = True def _init_weights(self, module): - std = 
self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv): + super()._init_weights(module) + if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): module.b1.data.zero_() module.b2.data.zero_() @@ -1602,18 +1591,8 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): _version = "0.0.5" def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Phi4MultimodalRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Phi4MultimodalImageEmbedding): + super()._init_weights(module) + if isinstance(module, Phi4MultimodalImageEmbedding): module.global_img_feature_extensor.data.zero_() module.sub_img_feature_extensor.data.zero_() diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index fefe5d69ab..7208cab16f 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -1125,19 +1125,8 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): _supports_flex_attn = True def _init_weights(self, module): - std = self.config.initializer_range - if 
isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv): + super()._init_weights(module) + if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): module.b1.data.zero_() module.b2.data.zero_() @@ -1460,18 +1449,8 @@ class Phi4MultimodalRotaryEmbedding(Phi3RotaryEmbedding): class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel): def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Phi4MultimodalRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, Phi4MultimodalImageEmbedding): + Phi3PreTrainedModel._init_weights(self, module) + if isinstance(module, Phi4MultimodalImageEmbedding): module.global_img_feature_extensor.data.zero_() module.sub_img_feature_extensor.data.zero_() diff --git a/src/transformers/models/plbart/configuration_plbart.py b/src/transformers/models/plbart/configuration_plbart.py index a605f9a1c8..a4aaa3ff37 100644 --- a/src/transformers/models/plbart/configuration_plbart.py +++ b/src/transformers/models/plbart/configuration_plbart.py @@ -101,7 +101,11 @@ class PLBartConfig(PretrainedConfig): model_type = "plbart" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size":
"d_model"} + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model", + "initializer_range": "init_std", + } def __init__( self, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 4236476349..541178bd38 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -80,17 +80,6 @@ class PLBartPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 3de32a625a..c0fcdea7af 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -66,17 +66,6 @@ class PLBartPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, diff --git 
a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index f8e9092a10..986e6e7e75 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -245,13 +245,6 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - class PromptDepthAnythingReassembleLayer(nn.Module): def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index fcd3c9c91f..988143000c 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -164,13 +164,6 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - class PromptDepthAnythingReassembleLayer(nn.Module): def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index e8c5a1bc8a..95b08ce24e 100644 --- 
a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -270,19 +270,6 @@ class Qwen2PreTrainedModel(PreTrainedModel): "attentions": Qwen2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) - class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config: Qwen2Config, device=None): diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 37c0c6da43..b8da2d7846 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -58,26 +58,6 @@ from .configuration_qwen2_5_omni import ( logger = logging.get_logger(__name__) -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - @auto_docstring class Qwen2_5OmniPreTrainedModel(PreTrainedModel): config: Qwen2_5OmniConfig @@ -90,27 +70,6 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): 
_supports_static_cache = False _supports_attention_backend = True - def _init_weights(self, module): - # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.weight is not None: - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) - class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): def _prepare_4d_causal_attention_mask_with_cache_position( @@ -1026,6 +985,26 @@ class Qwen2_5OmniMLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): def __init__(self, 
config: Qwen2_5OmniVisionEncoderConfig) -> None: super().__init__() diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index d40b0a073c..f899d0f8c4 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -35,7 +35,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLPreTrainedModel, Qwen2_5_VLTextModel, Qwen2_5_VLVisionBlock, - Qwen2RMSNorm, eager_attention_forward, ) from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig @@ -1136,27 +1135,6 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): config: Qwen2_5OmniConfig _supports_static_cache = False - def _init_weights(self, module): - # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.weight is not None: - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) - class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): def _prepare_4d_causal_attention_mask_with_cache_position( diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 
66fb7a7c06..d653b9de87 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -329,19 +329,6 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) - class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): config: Qwen2_5_VLVisionConfig diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 550df3750a..7c4e7117a2 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -174,18 +174,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) + pass class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py 
b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index b14264f455..6c17e0c9c4 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -663,22 +663,6 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _supports_static_cache = True _supports_attention_backend = True - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, Qwen2RMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 73f0631480..7cd5eb52b8 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -296,19 +296,6 @@ class Qwen3PreTrainedModel(PreTrainedModel): "attentions": Qwen3Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen3RMSNorm): - module.weight.data.fill_(1.0) - class Qwen3RotaryEmbedding(nn.Module): def __init__(self, config: Qwen3Config, device=None): diff --git 
a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index f37568777b..b7151bba05 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -419,19 +419,6 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): "attentions": Qwen3MoeAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen3MoeRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class Qwen3MoeModel(Qwen3MoePreTrainedModel): diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 474e30c584..ac030a4e86 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1017,19 +1017,8 @@ class SamPreTrainedModel(PreTrainedModel): _supports_sdpa = True def _init_weights(self, module: nn.Module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, (SamLayerNorm, nn.LayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, SamVisionAttention): + super()._init_weights(module) + if isinstance(module, SamVisionAttention): if module.use_rel_pos: module.rel_pos_h.data.zero_() module.rel_pos_w.data.zero_() diff 
--git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index fac2b15500..5c688e5089 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -482,20 +482,9 @@ class SamHQPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, (SamHQLayerNorm, nn.LayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - elif isinstance(module, SamHQVisionAttention): + def _init_weights(self, module: nn.Module): + super()._init_weights(module) + if isinstance(module, SamHQVisionAttention): if module.use_rel_pos: module.rel_pos_h.data.zero_() module.rel_pos_w.data.zero_() diff --git a/src/transformers/models/sam_hq/modular_sam_hq.py b/src/transformers/models/sam_hq/modular_sam_hq.py index 2a241fb2c0..2ea7596647 100644 --- a/src/transformers/models/sam_hq/modular_sam_hq.py +++ b/src/transformers/models/sam_hq/modular_sam_hq.py @@ -186,8 +186,7 @@ class SamHQVisionLayer(SamVisionLayer): class SamHQPreTrainedModel(SamPreTrainedModel): - def _init_weights(self, module): - super()._init_weights(module) + pass class SamHQVisionEncoder(SamVisionEncoder, SamHQPreTrainedModel): diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 92a3205a7b..c4287d1d23 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -300,19 +300,6 @@ class 
SmolLM3PreTrainedModel(PreTrainedModel): "attentions": SmolLM3Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, SmolLM3RMSNorm): - module.weight.data.fill_(1.0) - class SmolLM3RotaryEmbedding(nn.Module): def __init__(self, config: SmolLM3Config, device=None): diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index d83fe01c14..a9d8f043c4 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -47,6 +47,26 @@ from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig logger = logging.get_logger(__name__) +class SmolVLMRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + SmolVLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + @auto_docstring class SmolVLMPreTrainedModel(PreTrainedModel): config: SmolVLMConfig @@ -74,6 +94,8 @@ class SmolVLMPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() + elif isinstance(module, SmolVLMRMSNorm): + 
module.weight.data.fill_(1.0) class SmolVLMVisionEmbeddings(nn.Module): diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py index dfda67472c..72d7c85877 100644 --- a/src/transformers/models/smolvlm/modular_smolvlm.py +++ b/src/transformers/models/smolvlm/modular_smolvlm.py @@ -95,20 +95,7 @@ class SmolVLMVisionConfig(Idefics3VisionConfig): class SmolVLMPreTrainedModel(Idefics3PreTrainedModel): - def _init_weights(self, module): - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + pass class SmolVLMVisionTransformer(Idefics3VisionTransformer): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9e574d5349..87da8a3f7f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -304,20 +304,6 @@ class Starcoder2PreTrainedModel(PreTrainedModel): "attentions": Starcoder2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - 
@auto_docstring class Starcoder2Model(Starcoder2PreTrainedModel): diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 8157f37b6d..349ffb8acb 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -42,7 +42,6 @@ from ..mistral.modeling_mistral import ( MistralForSequenceClassification, MistralForTokenClassification, MistralModel, - MistralPreTrainedModel, MistralRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, @@ -141,22 +140,6 @@ class Starcoder2RotaryEmbedding(MistralRotaryEmbedding): pass -class Starcoder2PreTrainedModel(MistralPreTrainedModel): - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - - class Starcoder2Model(MistralModel): def __init__(self, config: Starcoder2Config): super().__init__(config) diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index c2fdbf5fc7..41ce05b295 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -594,18 +594,9 @@ class T5GemmaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): # TODO: support intialization for encoders and decoders separately(?) 
+ super()._init_weights(module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, T5GemmaRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, T5GemmaClassificationHead): + if isinstance(module, T5GemmaClassificationHead): scale = module.out_proj.weight.shape[0] ** -0.5 module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 5c72d76b4e..540a796f60 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -485,18 +485,9 @@ class T5GemmaPreTrainedModel(Gemma2PreTrainedModel): def _init_weights(self, module): # TODO: support intialization for encoders and decoders separately(?) 
+ Gemma2PreTrainedModel._init_weights(self, module) std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, T5GemmaRMSNorm): - module.weight.data.fill_(1.0) - elif isinstance(module, T5GemmaClassificationHead): + if isinstance(module, T5GemmaClassificationHead): scale = module.out_proj.weight.shape[0] ** -0.5 module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8a6686ac80..dd6f352376 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -306,22 +306,8 @@ class TimesFmPreTrainedModel(PreTrainedModel): _supports_sdpa = True def _init_weights(self, module): - if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0, std=self.config.initializer_range) - - elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0, std=self.config.initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - - elif isinstance(module, nn.LayerNorm): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) - - elif isinstance(module, TimesFmRMSNorm): - nn.init.zeros_(module.weight) - - elif isinstance(module, TimesFmAttention): + super()._init_weights(module) + if isinstance(module, TimesFmAttention): # Initialize scaling parameter nn.init.ones_(module.scaling) diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 683537637b..b82816e7c7 100644 ---
a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -262,22 +262,8 @@ class TimesFmPreTrainedModel(PreTrainedModel): _supports_sdpa = True def _init_weights(self, module): - if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0, std=self.config.initializer_range) - - elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0, std=self.config.initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - - elif isinstance(module, nn.LayerNorm): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) - - elif isinstance(module, TimesFmRMSNorm): - nn.init.zeros_(module.weight) - - elif isinstance(module, TimesFmAttention): + super()._init_weights(module) + if isinstance(module, TimesFmAttention): # Initialize scaling parameter nn.init.ones_(module.scaling) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index d3c263807b..2f3f536f79 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -126,20 +126,6 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True - def _init_weights(self, module): - # important: this ported version of VipLlava isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/VipLlava/tree/main/vipllava should serve for that purpose - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - @auto_docstring( custom_intro=""" diff 
--git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index fe1482bcdf..17a96aaad8 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1183,18 +1183,8 @@ class Zamba2PreTrainedModel(PreTrainedModel): _is_stateful = True def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, (Zamba2RMSNorm, Zamba2RMSNormGated)): - module.weight.data.fill_(1.0) - elif isinstance(module, Zamba2MambaMixer): + super()._init_weights(module) + if isinstance(module, Zamba2MambaMixer): dt = torch.exp( torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 05565c60d6..2e2c4cd608 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -906,18 +906,8 @@ class Zamba2PreTrainedModel(PreTrainedModel): _is_stateful = True def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, (Zamba2RMSNorm, Zamba2RMSNormGated)): - module.weight.data.fill_(1.0) - elif isinstance(module, Zamba2MambaMixer): + 
super()._init_weights(module) + if isinstance(module, Zamba2MambaMixer): dt = torch.exp( torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))