From 4e538409203402058d1ee75c57e03dd908eac508 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 14 Apr 2025 16:19:04 +0200 Subject: [PATCH] Detect and fix most `_init_weights()` issues - make it work for composite models (#37070) * Update test_modeling_common.py * Fix Llama and its modular children * Update test_modeling_common.py * qwen3 * first try at prioritizing models * Update test_modeling_common.py * Update test_modeling_common.py * Update test_modeling_common.py * test * fix * fix * more models * more * more * more * smarter init for composite models! * fix post rebase * smol * fix missing args * more * typo * Super elegant and efficient init for submodels * Update modeling_utils.py * style * last fixes * cleanup * finalize cleanup * CIs * improve docstring * Update modeling_utils.py * llama4 * style * CIs * style * add dpt * granite speech * qwen 2.5 omni * better fix * Parse the config file instead * CIs --- src/transformers/modeling_utils.py | 37 +- src/transformers/models/aria/modeling_aria.py | 17 +- src/transformers/models/aria/modular_aria.py | 17 +- .../models/aya_vision/modeling_aya_vision.py | 15 +- .../models/aya_vision/modular_aya_vision.py | 15 + .../models/bamba/modeling_bamba.py | 6 + .../models/bamba/modular_bamba.py | 6 + .../models/blip_2/modeling_blip_2.py | 30 +- .../models/chameleon/modeling_chameleon.py | 22 +- .../models/cohere/modeling_cohere.py | 2 + .../models/cohere/modular_cohere.py | 16 + .../models/cohere2/modeling_cohere2.py | 2 + .../deepseek_v3/modeling_deepseek_v3.py | 4 +- .../models/deepseek_v3/modular_deepseek_v3.py | 4 +- .../models/diffllama/modeling_diffllama.py | 7 + .../models/diffllama/modular_diffllama.py | 18 + src/transformers/models/dpt/modeling_dpt.py | 2 +- src/transformers/models/emu3/modeling_emu3.py | 18 +- src/transformers/models/emu3/modular_emu3.py | 18 +- .../models/gemma/modeling_gemma.py | 2 + .../models/gemma2/modeling_gemma2.py | 2 + .../models/gemma3/modeling_gemma3.py | 12 +- 
.../models/gemma3/modular_gemma3.py | 12 +- src/transformers/models/glm/modeling_glm.py | 2 + src/transformers/models/glm4/modeling_glm4.py | 2 + .../models/got_ocr2/modeling_got_ocr2.py | 26 +- .../models/got_ocr2/modular_got_ocr2.py | 18 +- .../modeling_gpt_neox_japanese.py | 3 + .../models/granite/modeling_granite.py | 2 + .../granite_speech/modeling_granite_speech.py | 6 +- .../models/granitemoe/modeling_granitemoe.py | 3 +- .../modeling_granitemoeshared.py | 3 +- .../models/helium/modeling_helium.py | 2 + .../models/idefics/modeling_idefics.py | 24 +- .../models/idefics2/configuration_idefics2.py | 4 + .../models/idefics2/modeling_idefics2.py | 20 +- .../models/idefics3/modeling_idefics3.py | 15 +- .../instructblip/modeling_instructblip.py | 22 +- .../modeling_instructblipvideo.py | 377 +++++++++--------- .../models/jamba/modeling_jamba.py | 7 + .../models/jetmoe/modeling_jetmoe.py | 3 +- .../models/llama/modeling_llama.py | 2 + .../models/llama4/modeling_llama4.py | 11 + .../models/llava/modeling_llava.py | 15 +- .../models/llava_next/modeling_llava_next.py | 22 +- .../modeling_llava_next_video.py | 101 +++-- .../modular_llava_next_video.py | 16 +- .../modeling_llava_onevision.py | 23 +- src/transformers/models/mimi/modeling_mimi.py | 17 +- .../models/mistral/modeling_mistral.py | 2 + .../models/mistral3/modeling_mistral3.py | 20 +- .../models/mistral3/modular_mistral3.py | 14 +- .../models/mixtral/modeling_mixtral.py | 2 + .../models/mllama/modeling_mllama.py | 17 +- .../models/moonshine/modeling_moonshine.py | 4 + .../models/moonshine/modular_moonshine.py | 4 + .../models/moshi/modeling_moshi.py | 15 +- .../models/nemotron/modeling_nemotron.py | 3 + src/transformers/models/olmo/modeling_olmo.py | 68 ++-- src/transformers/models/olmo/modular_olmo.py | 16 +- .../models/olmo2/modeling_olmo2.py | 70 ++-- .../models/olmo2/modular_olmo2.py | 11 +- .../models/olmoe/modeling_olmoe.py | 2 + src/transformers/models/opt/modeling_opt.py | 3 + 
.../models/paligemma/modeling_paligemma.py | 15 +- .../models/persimmon/modeling_persimmon.py | 3 + src/transformers/models/phi/modeling_phi.py | 71 ++-- src/transformers/models/phi/modular_phi.py | 19 +- src/transformers/models/phi3/modeling_phi3.py | 2 + .../modeling_phi4_multimodal.py | 76 ++-- .../modular_phi4_multimodal.py | 36 +- .../models/phimoe/modeling_phimoe.py | 3 + .../models/pixtral/modeling_pixtral.py | 13 +- .../modeling_prompt_depth_anything.py | 7 +- .../modular_prompt_depth_anything.py | 7 +- .../models/qwen2/modeling_qwen2.py | 2 + .../configuration_qwen2_5_omni.py | 2 + .../qwen2_5_omni/modeling_qwen2_5_omni.py | 47 ++- .../qwen2_5_omni/modular_qwen2_5_omni.py | 10 +- .../qwen2_5_vl/configuration_qwen2_5_vl.py | 2 + .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 + .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 15 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 + .../models/qwen2_vl/configuration_qwen2_vl.py | 2 + .../models/qwen2_vl/modeling_qwen2_vl.py | 6 + .../models/qwen3/modeling_qwen3.py | 2 + .../models/qwen3_moe/modeling_qwen3_moe.py | 2 + .../modeling_recurrent_gemma.py | 7 + .../models/rt_detr/modeling_rt_detr.py | 12 +- .../models/rt_detr_v2/modeling_rt_detr_v2.py | 12 +- .../models/smolvlm/modeling_smolvlm.py | 12 +- .../models/smolvlm/modular_smolvlm.py | 15 +- .../models/stablelm/modeling_stablelm.py | 3 + .../models/starcoder2/modeling_starcoder2.py | 71 ++-- .../models/starcoder2/modular_starcoder2.py | 22 + .../models/upernet/modeling_upernet.py | 37 +- .../models/vipllava/modeling_vipllava.py | 22 +- .../models/whisper/modeling_whisper.py | 10 +- .../models/zamba/modeling_zamba.py | 13 +- .../models/zamba2/modeling_zamba2.py | 12 +- .../models/zamba2/modular_zamba2.py | 12 +- .../test_modeling_phi4_multimodal.py | 2 + tests/test_modeling_common.py | 70 ++++ 103 files changed, 1164 insertions(+), 795 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a73aecfb88..5a8ebc6847 
100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2449,6 +2449,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi self._init_weights(module) module._is_hf_initialized = True + @torch.no_grad() + def initialize_weights(self): + """ + This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models. + This function dynamically dispatches the correct `_init_weights` function to the modules as we advance in the + module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite + model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which + is extremely error-prone and inefficient. + + Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use + `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as + `module.weight.data.zero_()`. + """ + if not hasattr(torch.nn.Module, "smart_apply"): + # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function + # to apply as we go down the graph + def smart_apply(self, fn): + for module in self.children(): + # We found a sub-model: recursively dispatch its own init function now! + if hasattr(module, "_init_weights"): + module.smart_apply(module._initialize_weights) + else: + module.smart_apply(fn) + fn(self) + return self + + torch.nn.Module.smart_apply = smart_apply + + # Let the magic happen with this simple call + self.smart_apply(self._initialize_weights) + def tie_weights(self): """ Tie the weights between the input embeddings and the output embeddings. 
@@ -3074,7 +3105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if _init_weights: # Initialize weights - self.apply(self._initialize_weights) + self.initialize_weights() # Tie weights should be skipped when not initializing all weights # since from_pretrained(...) calls tie weights anyways @@ -5286,9 +5317,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi ) ) with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): - self.apply(self._initialize_weights) + self.initialize_weights() else: - self.apply(self._initialize_weights) + self.initialize_weights() def get_parameter_or_buffer(self, target: str): """ diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index fdb825cad3..867ad34b55 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -679,12 +679,10 @@ class AriaTextPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaTextRMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, AriaGroupedExpertsGemm): module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Conv2d): - module.weight.data.normal_(mean=0.0, std=std) - if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() ARIA_TEXT_START_DOCSTRING = r""" @@ -724,14 +722,17 @@ class AriaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_range + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif 
isinstance(module, nn.MultiheadAttention): + # This uses torch's original init + module._reset_parameters() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() elif isinstance(module, AriaProjector): nn.init.trunc_normal_(module.query, std=std) diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index fa0858cde3..574ee053a9 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1255,12 +1255,10 @@ class AriaTextPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaTextRMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, AriaGroupedExpertsGemm): module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Conv2d): - module.weight.data.normal_(mean=0.0, std=std) - if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() class AriaPreTrainedModel(LlamaPreTrainedModel): @@ -1269,14 +1267,17 @@ class AriaPreTrainedModel(LlamaPreTrainedModel): def _init_weights(self, module): std = self.config.initializer_range + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.MultiheadAttention): + # This uses torch's original init + module._reset_parameters() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() elif isinstance(module, AriaProjector): nn.init.trunc_normal_(module.query, std=std) diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py 
b/src/transformers/models/aya_vision/modeling_aya_vision.py index 1e6e76a210..8c1a3b23b9 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -127,26 +127,19 @@ class AyaVisionPreTrainedModel(PreTrainedModel): _supports_static_cache = False def _init_weights(self, module): - # important: this ported version of AyaVision isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/AyaVision/tree/main/aya_vision should serve for that purpose std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @dataclass diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index b046275a2d..96f0de888d 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -113,6 +113,21 @@ class AyaVisionPreTrainedModel(LlavaPreTrainedModel): _supports_quantized_cache = False _supports_static_cache = False + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if isinstance(module, 
nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 8fd2483bcd..5d70cf0dda 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1052,10 +1052,16 @@ class BambaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)): + module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, BambaMixer): + module.dt_bias.data.fill_(1.0) + module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) + module.D.data.fill_(1.0) BAMBA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 5aa4c8fb40..9c52fb16ac 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -820,10 +820,16 @@ class BambaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)): + module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, BambaMixer): + module.dt_bias.data.fill_(1.0) + module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) + 
module.D.data.fill_(1.0) BAMBA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 424fbf0a5a..1ee48b00b8 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -423,22 +423,30 @@ class Blip2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range - if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear): + + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=factor) - if hasattr(module, "bias") and module.bias is not None: + if module.bias is not None: module.bias.data.zero_() - - if isinstance(module, Blip2VisionEmbeddings): - if hasattr(self.config, "vision_config") and not isinstance(self.config, Blip2VisionConfig): - factor = self.config.vision_config.initializer_range - nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) - nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) - + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) - elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + elif isinstance(module, Blip2VisionEmbeddings): + nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + elif isinstance( + module, + ( + Blip2Model, + Blip2TextModelWithProjection, + Blip2VisionModelWithProjection, + Blip2ForConditionalGeneration, + Blip2ForImageTextRetrieval, + ), + ): + module.query_tokens.data.zero_() BLIP_2_START_DOCSTRING = r""" diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py 
index 4d950ec668..ad0c255023 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1056,12 +1056,16 @@ class ChameleonPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_range - if isinstance(module, ChameleonVQVAE): - module.apply(module._init_weights) - elif isinstance(module, (nn.Linear, nn.Conv2d)): + + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ChameleonRMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: @@ -1096,18 +1100,6 @@ class ChameleonVQVAE(ChameleonPreTrainedModel): config_class = ChameleonVQVAEConfig _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.GroupNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - def __init__(self, config: ChameleonVQVAEConfig): super().__init__(config) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 8cbb7128c7..e8e6bba237 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -416,6 +416,8 @@ class CoherePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: 
module.weight.data[module.padding_idx].zero_() + elif isinstance(module, CohereLayerNorm): + module.weight.data.fill_(1.0) COHERE_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 17644ff4f8..c9f2d8ff27 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -41,6 +41,7 @@ from ..llama.modeling_llama import ( LlamaForCausalLM, LlamaMLP, LlamaModel, + LlamaPreTrainedModel, LlamaRotaryEmbedding, eager_attention_forward, ) @@ -277,6 +278,21 @@ class CohereDecoderLayer(nn.Module): return outputs +class CoherePreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, CohereLayerNorm): + module.weight.data.fill_(1.0) + + class CohereModel(LlamaModel): def __init__(self, config: CohereConfig): super().__init__(config) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 18a3a50ac1..65e066f90c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -424,6 +424,8 @@ class Cohere2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Cohere2LayerNorm): + module.weight.data.fill_(1.0) COHERE2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 
c1a020b7c7..586d5251b0 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -557,10 +557,10 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DeepseekV3RMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, DeepseekV3TopkRouter): module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Parameter): - module.weight.data.normal_(mean=0.0, std=std) DEEPSEEK_V3_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 7713eb3b27..b4905c6201 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -347,10 +347,10 @@ class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DeepseekV3RMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, DeepseekV3TopkRouter): module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Parameter): - module.weight.data.normal_(mean=0.0, std=std) class DeepseekV3Model(LlamaModel): diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index e7fecb4be6..f6ff065334 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -625,6 +625,13 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821 
+ module.weight.data.fill_(1.0) + elif isinstance(module, DiffLlamaAttention): + module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) class DiffLlamaRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 6810234790..f7bc2d2c5a 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -431,6 +431,24 @@ class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): _supports_flex_attn = False _supports_attention_backend = False + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) + elif isinstance(module, DiffLlamaAttention): + module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + class DiffLlamaModel(LlamaModel): pass diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index c69bf618fe..0dc1dcf2f8 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -852,7 +852,7 @@ class DPTPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if 
module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)): diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index fcc55b67d1..1d85fcb639 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1020,6 +1020,10 @@ class Emu3VQVAE(PreTrainedModel): def _init_weights(self, module): if isinstance(module, (nn.Conv2d, nn.Conv3d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.Linear): nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: @@ -1027,8 +1031,12 @@ class Emu3VQVAE(PreTrainedModel): bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() def __init__(self, config: Emu3VQVAEConfig): super().__init__(config) @@ -1198,9 +1206,7 @@ class Emu3PreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.get_text_config().initializer_range - if isinstance(module, Emu3VQVAE): - module.apply(module._init_weights) - elif isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, 
std=std) if module.bias is not None: module.bias.data.zero_() @@ -1208,6 +1214,8 @@ class Emu3PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Emu3RMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) class Emu3RotaryEmbedding(nn.Module): diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 031dc26f0a..62e95a7f73 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -747,6 +747,10 @@ class Emu3VQVAE(PreTrainedModel): def _init_weights(self, module): if isinstance(module, (nn.Conv2d, nn.Conv3d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.Linear): nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: @@ -754,8 +758,12 @@ class Emu3VQVAE(PreTrainedModel): bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): - nn.init.constant_(module.weight, 1) - nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() def __init__(self, config: Emu3VQVAEConfig): super().__init__(config) @@ -894,9 +902,7 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE): def _init_weights(self, module): std = self.config.get_text_config().initializer_range - if isinstance(module, Emu3VQVAE): - module.apply(module._init_weights) - elif 
isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() @@ -904,6 +910,8 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Emu3RMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) EMU3_TEXT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 4049743328..f9bcf181c5 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -381,6 +381,8 @@ class GemmaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GemmaRMSNorm): + module.weight.data.fill_(1.0) GEMMA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 144a94ef33..ecbfedb2ad 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -426,6 +426,8 @@ class Gemma2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma2RMSNorm): + module.weight.data.fill_(1.0) GEMMA2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 0988e2692a..7d0447a545 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -486,13 +486,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_attention_backend = 
True def _init_weights(self, module): - # important: this ported version of Gemma2 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) @@ -502,6 +496,10 @@ class Gemma3PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma3RMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Gemma3MultiModalProjector): + module.mm_input_projection_weight.data.zero_() GEMMA3_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 3f7292f13a..afcb72c202 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -548,13 +548,7 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel): ] def _init_weights(self, module): - # important: this ported version of Gemma2 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) @@ -564,6 +558,10 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma3RMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, 
Gemma3MultiModalProjector): + module.mm_input_projection_weight.data.zero_() class Gemma3TextModel(Gemma2Model): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 2a41815d74..044a401402 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -399,6 +399,8 @@ class GlmPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GlmRMSNorm): + module.weight.data.fill_(1.0) GLM_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 8fb2a0abf5..6356356f3e 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -407,6 +407,8 @@ class Glm4PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Glm4RMSNorm): + module.weight.data.fill_(1.0) GLM4_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index c474ef3690..17db99b2a3 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -591,26 +591,22 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_static_cache = True def _init_weights(self, module): - # important: this ported version of GotOcr2 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/GotOcr2/tree/main/got_ocr2 should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else 
self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821 + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, GotOcr2VisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + elif isinstance(module, GotOcr2VisionEncoder): + if module.pos_embed is not None: + module.pos_embed.data.zero_() GOT_OCR2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 4b7e7f1adb..e3daafd81c 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -276,7 +276,23 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): class GotOcr2PreTrainedModel(LlavaPreTrainedModel): - pass + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, GotOcr2LayerNorm)): # noqa: F821 + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, GotOcr2VisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + elif isinstance(module, 
GotOcr2VisionEncoder): + if module.pos_embed is not None: + module.pos_embed.data.zero_() GOT_OCR2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index e08624bf81..43e860a775 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -75,6 +75,9 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, GPTNeoXJapaneseAttention): + if module.dense_bias is not None: + module.dense_bias.data.zero_() class GPTNeoXJapaneseAttention(nn.Module): diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 80d3ad696d..445edab65b 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -366,6 +366,8 @@ class GranitePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GraniteRMSNorm): + module.weight.data.fill_(1.0) class GraniteRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 821539d416..6ae4e6c35f 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -330,11 +330,15 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx 
is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, GraniteSpeechEncoderProjector): + module.query.data.normal_() GRANITE_SPEECH_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 9ea68ab0aa..6217888424 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -833,8 +833,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() + elif isinstance(module, GraniteMoeRMSNorm): module.weight.data.fill_(1.0) elif isinstance(module, GraniteMoeParallelExperts): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 886bed0968..787b805ee8 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -745,8 +745,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() + elif isinstance(module, GraniteMoeSharedRMSNorm): module.weight.data.fill_(1.0) elif isinstance(module, GraniteMoeSharedParallelExperts): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/helium/modeling_helium.py 
b/src/transformers/models/helium/modeling_helium.py index d565af9e27..61bf2f2d09 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -384,6 +384,8 @@ class HeliumPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, HeliumRMSNorm): + module.weight.data.fill_(1.0) HELIUM_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index fb5006d2e2..770eb90ab1 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -44,7 +44,7 @@ from ...utils import ( ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionTransformer +from .vision import IdeficsVisionEmbeddings, IdeficsVisionTransformer if is_torch_flex_attn_available(): @@ -934,7 +934,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): # inference and fine-tuning - so the proper init weights code has been removed - the m4 code # base should be used for training from scratch and it contains the correct code. 
std = self.config.initializer_range - if isinstance(module, nn.Linear): + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() @@ -942,6 +942,25 @@ class IdeficsPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, IdeficsRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, IdeficsVisionEmbeddings): + module.class_embedding.data.normal_() + elif isinstance(module, IdeficsGatedCrossAttentionLayer): + if self.config.alpha_initializer == "zeros": + module.alpha_cross_attn.data.zero_() + module.alpha_dense.data.zero_() + elif self.config.alpha_initializer == "ones": + module.alpha_cross_attn.data.fill_(1.0) + module.alpha_dense.data.fill_(1.0) + elif self.config.alpha_initializer in {"normal", "gaussian", "random"}: + module.alpha_cross_attn.data.normal_(mean=0.0, std=self.config.alphas_initializer_range) + module.alpha_dense.data.normal_(mean=0.0, std=self.config.alphas_initializer_range) + elif isinstance(module, IdeficsPerceiverResampler): + module.latents.data.normal_() LLAMA_INPUTS_DOCSTRING = r""" @@ -1495,7 +1514,6 @@ class IdeficsModel(IdeficsPreTrainedModel): class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] def __init__(self, config, vision_model=None): diff --git a/src/transformers/models/idefics2/configuration_idefics2.py b/src/transformers/models/idefics2/configuration_idefics2.py index 2f0376a895..31912d6ad9 100644 --- a/src/transformers/models/idefics2/configuration_idefics2.py +++ b/src/transformers/models/idefics2/configuration_idefics2.py @@ -130,6 +130,8 @@ class 
Idefics2PerceiverConfig(PretrainedConfig): Number of key-value heads in the perceiver attention block. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation for initializing all weight matrices in the model. """ model_type = "idefics2_perceiver" @@ -145,6 +147,7 @@ class Idefics2PerceiverConfig(PretrainedConfig): resampler_head_dim=96, num_key_value_heads=4, attention_dropout=0.0, + initializer_range=0.02, **kwargs, ): self.hidden_act = hidden_act @@ -156,6 +159,7 @@ class Idefics2PerceiverConfig(PretrainedConfig): self.num_key_value_heads = num_key_value_heads self.resampler_head_dim = resampler_head_dim self.attention_dropout = attention_dropout + self.initializer_range = initializer_range if self.num_key_value_heads > self.resampler_n_heads: raise ValueError( f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to" diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 16ff2873b1..d23085bd37 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -517,14 +517,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _supports_cache_class = True def _init_weights(self, module): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.get_text_config().initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) @@ -534,6 +527,17 @@ class Idefics2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: 
module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Idefics2RMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, nn.MultiheadAttention): + module._reset_parameters() # native torch init + elif isinstance(module, Idefics2MultiheadAttentionPoolingHead): + module.probe.data.normal_() + elif isinstance(module, Idefics2PerceiverResampler): + module.latents.data.fill_(1.0) IDEFICS2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 64193c2a5d..5945fd71c5 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -533,16 +533,8 @@ class Idefics3PreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_cache_class = True - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights def _init_weights(self, module): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.get_text_config().initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) @@ -552,6 +544,11 @@ class Idefics3PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Idefics3RMSNorm): + module.weight.data.fill_(1.0) IDEFICS3_VISION_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index d685dd6e99..a304353cc4 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -323,26 +323,24 @@ class InstructBlipPreTrainedModel(PreTrainedModel): "InstructBlipQFormerSelfOutput", ] - # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range - if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear): + + if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=factor) - if hasattr(module, "bias") and module.bias is not None: + if module.bias is not None: module.bias.data.zero_() - - if isinstance(module, InstructBlipVisionEmbeddings): - if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVisionConfig): - factor = self.config.vision_config.initializer_range - nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) - nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) - + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) - elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + elif isinstance(module, InstructBlipVisionEmbeddings): + nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + elif isinstance(module, InstructBlipForConditionalGeneration): + module.query_tokens.data.zero_() INSTRUCTBLIP_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index cdc64c6802..d2a6c7b6f1 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -130,44 +130,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module): return embeddings -class InstructBlipVideoPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = InstructBlipVideoConfig - base_model_prefix = "blip" - supports_gradient_checkpointing = True - - _no_split_modules = [ - "InstructBlipVideoQFormerEmbeddings", - "InstructBlipVideoAttention", - "InstructBlipVideoQFormerMultiHeadAttention", - "InstructBlipVideoQFormerSelfOutput", - ] - - def _init_weights(self, module): - """Initialize the weights""" - factor = self.config.initializer_range - if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=factor) - if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() - - if isinstance(module, InstructBlipVideoVisionEmbeddings): - if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVideoVisionConfig): - factor = self.config.vision_config.initializer_range - nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) - nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) - - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - class InstructBlipVideoAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -416,73 +378,6 @@ 
INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r""" """ -class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): - main_input_name = "pixel_values" - config_class = InstructBlipVideoVisionConfig - - def __init__(self, config: InstructBlipVideoVisionConfig): - super().__init__(config) - self.config = config - embed_dim = config.hidden_size - - self.embeddings = InstructBlipVideoVisionEmbeddings(config) - self.encoder = InstructBlipVideoEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - - self.post_init() - - @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig) - def forward( - self, - pixel_values: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - interpolate_pos_encoding: bool = False, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - if not 
return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - def get_input_embeddings(self): - return self.embeddings - - class InstructBlipVideoQFormerMultiHeadAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() @@ -957,6 +852,194 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module): return embeddings +INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See + [`InstructBlipVideoProcessor.__call__`] for details. + + qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided + to serve as text prompt, which the Q-Former model will encode. 
+ + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be + provided to serve as text prompt, which the language model can continue. + + Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an + encoder-decoder language model (like T5) is used. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) + + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. 
+ + Only relevant in case an encoder-decoder language model (like T5) is used. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). +""" + + +class InstructBlipVideoPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = InstructBlipVideoConfig + base_model_prefix = "blip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "InstructBlipVideoQFormerEmbeddings", + "InstructBlipVideoAttention", + "InstructBlipVideoQFormerMultiHeadAttention", + "InstructBlipVideoQFormerSelfOutput", + ] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_range + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=factor) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, InstructBlipVideoVisionEmbeddings): + nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) + nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) + elif isinstance(module, InstructBlipVideoForConditionalGeneration): + module.query_tokens.data.zero_() + + +class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): + main_input_name = "pixel_values" + config_class = InstructBlipVideoVisionConfig + + def __init__(self, config: InstructBlipVideoVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = InstructBlipVideoVisionEmbeddings(config) + self.encoder = InstructBlipVideoEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + 
interpolate_pos_encoding: bool = False, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel): """ Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the @@ -1186,90 +1269,6 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput): ) -INSTRUCTBLIPVIDEO_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) 
- - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See - [`InstructBlipVideoProcessor.__call__`] for details. - - qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided - to serve as text prompt, which the Q-Former model will encode. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be - provided to serve as text prompt, which the language model can continue. - - Indices can be obtained using [`InstructBlipVideoProcessor`]. 
See [`InstructBlipVideoProcessor.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an - encoder-decoder language model (like T5) is used. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) - - decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - Only relevant in case an encoder-decoder language model (like T5) is used. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). 
-""" - - @add_start_docstrings( """ InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index abf4ad4b3d..c91cc6e299 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1115,6 +1115,13 @@ class JambaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, JambaRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, JambaMambaMixer): + A = torch.arange(1, module.ssm_state_size + 1)[None, :] + A = A.expand(module.intermediate_size, -1).contiguous() + module.A_log.data.copy_(torch.log(A)) + module.D.data.fill_(1.0) JAMBA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 70cb8126dc..44e0e5b809 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -856,8 +856,7 @@ class JetMoePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() + elif isinstance(module, JetMoeRMSNorm): module.weight.data.fill_(1.0) elif isinstance(module, JetMoeParallelExperts): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d36fb1b6a4..eec1ecfee3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -389,6 +389,8 @@ class LlamaPreTrainedModel(PreTrainedModel): 
module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LlamaRMSNorm): + module.weight.data.fill_(1.0) LLAMA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 1fed0c9ca2..37614db7c2 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -492,6 +492,17 @@ class Llama4PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Llama4TextRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Llama4TextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + elif isinstance(module, Llama4VisionModel): + module.class_embedding.data.normal_(std=module.scale) + module.positional_embedding_vlm.data.normal_(std=module.scale) LLAMA4_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index c1d075b641..a8b5f074d0 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -144,23 +144,12 @@ class LlavaPreTrainedModel(PreTrainedModel): # important: this ported version of Llava isn't meant for training from scratch - only # inference and fine-tuning - so the proper init weights code has been removed - the original codebase # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = getattr(self.config, 
"initializer_range", self.config.get_text_config().initializer_range) - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() LLAVA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 06fc6bfedb..6301402e6e 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -236,7 +236,6 @@ LLAVA_NEXT_START_DOCSTRING = r""" "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAVA_NEXT_START_DOCSTRING, ) -# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next class LlavaNextPreTrainedModel(PreTrainedModel): config_class = LlavaNextConfig base_model_prefix = "model" @@ -250,26 +249,15 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_static_cache = True def _init_weights(self, module): - # important: this ported version of LlavaNext isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - if hasattr(module, "class_embedding"): - 
module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LlavaNextForConditionalGeneration): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + module.image_newline.data.normal_(mean=0.0, std=embed_std) LLAVA_NEXT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index bf30ff17c0..b4a9c899c9 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -129,62 +129,6 @@ class LlavaNextVideoPooler(nn.Module): return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() -LLAVA_NEXT_VIDEO_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlavaNextVideoConfig`] or [`LlavaNextVideoVisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAVA_NEXT_VIDEO_START_DOCSTRING, -) -class LlavaNextVideoPreTrainedModel(PreTrainedModel): - config_class = LlavaNextVideoConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlavaNextVideoVisionAttention"] - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/LLaVA/tree/main/llava_next_video should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - class LlavaNextVideoMultiModalProjector(nn.Module): def __init__(self, config: LlavaNextVideoConfig): super().__init__() @@ -207,6 +151,23 @@ class LlavaNextVideoMultiModalProjector(nn.Module): return hidden_states +LLAVA_NEXT_VIDEO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlavaNextVideoConfig`] or [`LlavaNextVideoVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ Calculate the shape of the image patch grid after the preprocessing for images of any resolution. @@ -394,6 +355,34 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" """ +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAVA_NEXT_VIDEO_START_DOCSTRING, +) +class LlavaNextVideoPreTrainedModel(PreTrainedModel): + config_class = LlavaNextVideoConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaNextVideoVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, LlavaNextVideoForConditionalGeneration): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + module.image_newline.data.normal_(mean=0.0, std=embed_std) + + @add_start_docstrings( """The LLAVA-NeXT model which consists of a vision backbone 
and a language model.""", LLAVA_NEXT_VIDEO_START_DOCSTRING, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 8168682ad7..0109082d84 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -24,6 +24,7 @@ from torch import nn from transformers.models.llava_next.modeling_llava_next import ( LlavaNextCausalLMOutputWithPast, LlavaNextForConditionalGeneration, + LlavaNextMultiModalProjector, LlavaNextPreTrainedModel, image_size_to_num_patches, ) @@ -222,10 +223,23 @@ class LlavaNextVideoPooler(nn.Module): return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() -class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel): +class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector): pass +class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel): + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, LlavaNextVideoForConditionalGeneration): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + module.image_newline.data.normal_(mean=0.0, std=embed_std) + + class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): super().__init__(config, **super_kwargs) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 31d5b9edb6..3f77d39c02 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ 
b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -255,28 +255,17 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_sdpa = True - # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights with LlavaNext->LlavaOnevision def _init_weights(self, module): - # important: this ported version of LlavaNext isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LlavaOnevisionForConditionalGeneration): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + module.image_newline.data.normal_(mean=0.0, std=embed_std) LLAVA_ONEVISION_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 317f5ae01b..02aedb8c01 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1412,31 +1412,22 @@ class 
MimiPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_static_cache = True - # Copied from transformers.models.encodec.modeling_encodec.EncodecPreTrainedModel._init_weights def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) - elif isinstance(module, nn.Conv1d): + elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LSTM): - for name, param in module.named_parameters(): - if "weight" in name: - nn.init.xavier_uniform_(param) - elif "bias" in name: - nn.init.constant_(param, 0.0) + elif isinstance(module, MimiLayerScale): + module.scale.data.fill_(self.config.layer_scale_initial_scale) MIMI_START_DOCSTRING = r""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2ee9f9ccd7..7de6cad370 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -318,6 +318,8 @@ class MistralPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) class MistralRotaryEmbedding(nn.Module): diff --git 
a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 6d60202db4..08a8b7b315 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -203,26 +203,14 @@ class Mistral3PreTrainedModel(PreTrainedModel): _supports_static_cache = True def _init_weights(self, module): - # important: this ported version of Mistral3 isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/Mistral3/tree/main/mistral3 should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Mistral3RMSNorm): + module.weight.data.fill_(1.0) MISTRAL3_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 3793bef183..5eebcd8d56 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -20,7 +20,7 @@ from torch import nn from ...activations import ACT2FN from ...utils import is_torchdynamo_compiling, logging -from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration 
+from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel from ..mistral.modeling_mistral import MistralRMSNorm from .configuration_mistral3 import Mistral3Config @@ -100,6 +100,18 @@ class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass +class Mistral3PreTrainedModel(LlavaPreTrainedModel): + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, Mistral3RMSNorm): + module.weight.data.fill_(1.0) + + class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): def get_image_features( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 3075eb01f6..001692bd75 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -474,6 +474,8 @@ class MixtralPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MixtralRMSNorm): + module.weight.data.fill_(1.0) MIXTRAL_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 64329c2abd..a2911a46a0 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1029,7 +1029,8 @@ class MllamaPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True def _init_weights(self, module): - std = self.config.get_text_config().initializer_range + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + if isinstance(module, (nn.Linear, 
nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: @@ -1038,15 +1039,25 @@ class MllamaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.Parameter): - module.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, MllamaTextRMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, MllamaVisionModel): nn.init.normal_(module.class_embedding.data, std=std) elif isinstance(module, MllamaPrecomputedPositionEmbedding): nn.init.normal_(module.embedding.data, std=std) + nn.init.zeros_(module.gate.data) elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated: nn.init.normal_(module.gate_attn.data, std=std) nn.init.normal_(module.gate_ffn.data, std=std) + elif isinstance(module, MllamaCrossAttentionDecoderLayer): + module.cross_attn_attn_gate.data.zero_() + module.cross_attn_mlp_gate.data.zero_() + elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding): + if module.is_gated: + module.gate.data.zero_() # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index c9e6eb756c..aaeb405a0e 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -536,6 +536,10 @@ class MoonshinePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Embedding): 
module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 9d6f2c52c5..d76941c8d6 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -554,6 +554,10 @@ class MoonshinePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)): + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 609abb1deb..dc94efd353 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -849,22 +849,19 @@ class MoshiPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_range - if isinstance(module, (nn.Linear, nn.Conv1d)): + + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Conv1d): - nn.init.kaiming_normal_(module.weight) - if module.bias is not None: - k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) - nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, MoshiFlexibleLinear): + module.weight.data.normal_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, 
MoshiRMSNorm): + module.weight.data.fill_(1.0) MOSHI_START_DOCSTRING = r""" diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 31d14fb5da..a17833912f 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -623,6 +623,9 @@ class NemotronPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, NemotronLayerNorm1P): + module.weight.data.fill_(1.0) + module.bias.data.zero_() NEMOTRON_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 5b6ca9f4b3..bf5e80b839 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -282,6 +282,40 @@ class OlmoDecoderLayer(nn.Module): return outputs +class OlmoRotaryEmbedding(nn.Module): + def __init__(self, config: OlmoConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + OLMO_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -329,40 +363,6 @@ class OlmoPreTrainedModel(PreTrainedModel): module.weight.data[module.padding_idx].zero_() -class OlmoRotaryEmbedding(nn.Module): - def __init__(self, config: OlmoConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - OLMO_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index ca290ba9ae..a2af171557 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -15,6 +15,7 @@ from ..llama.modeling_llama import ( LlamaMLP, LlamaModel, LlamaPreTrainedModel, + LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, ) @@ -114,10 +115,23 @@ class OlmoDecoderLayer(LlamaDecoderLayer): self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) -class OlmoPreTrainedModel(LlamaPreTrainedModel): +class OlmoRotaryEmbedding(LlamaRotaryEmbedding): pass +class OlmoPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + class OlmoModel(LlamaModel): def __init__(self, config: OlmoConfig): super().__init__(config) diff --git 
a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 4046dc5826..e44ea5f62b 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -286,6 +286,40 @@ class Olmo2DecoderLayer(nn.Module): return outputs +class Olmo2RotaryEmbedding(nn.Module): + def __init__(self, config: Olmo2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + OLMO2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -331,40 +365,8 @@ class Olmo2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - -class Olmo2RotaryEmbedding(nn.Module): - def __init__(self, config: Olmo2Config, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + elif isinstance(module, Olmo2RMSNorm): + module.weight.data.fill_(1.0) OLMO2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index fbd431532f..c43263e954 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -7,13 +7,14 @@ from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import logging -from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward +from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward from ..olmo.configuration_olmo import OlmoConfig from ..olmo.modeling_olmo import ( OlmoAttention, OlmoDecoderLayer, OlmoForCausalLM, OlmoModel, + OlmoRotaryEmbedding, apply_rotary_pos_emb, ) @@ -287,6 +288,14 @@ class Olmo2DecoderLayer(OlmoDecoderLayer): return outputs +class Olmo2RotaryEmbedding(OlmoRotaryEmbedding): + pass + + +class Olmo2PreTrainedModel(LlamaPreTrainedModel): + pass + + # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of # standard layer norm for the output norm. 
class Olmo2Model(OlmoModel): diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index ce38e57cfd..007da568f0 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -747,6 +747,8 @@ class OlmoePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, OlmoeRMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5a1f15f94c..d0175055cc 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -511,6 +511,9 @@ class OPTPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() OPT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index ef92378e0b..f13bd0feef 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -199,23 +199,12 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): # important: this ported version of PaliGemmaisn't meant for training from scratch - only # inference and fine-tuning - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - if hasattr(module, "class_embedding"): - 
module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() PALIGEMMA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 029e376eaf..da8a8d2927 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -412,6 +412,9 @@ class PersimmonPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() PERSIMMON_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index ffb36ed45f..8572b1546a 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -279,6 +279,40 @@ class PhiDecoderLayer(nn.Module): return outputs +class PhiRotaryEmbedding(nn.Module): + def __init__(self, config: PhiConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = 
self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + PHI_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -324,40 +358,9 @@ class PhiPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - -class PhiRotaryEmbedding(nn.Module): - def __init__(self, config: PhiConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - 
self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() PHI_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 5661432f8d..5faee931e0 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -20,6 +20,7 @@ from ..llama.modeling_llama import ( LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, + LlamaRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, # copied from Llama ) @@ -170,10 +171,26 @@ class PhiDecoderLayer(nn.Module): return outputs -class PhiPreTrainedModel(LlamaPreTrainedModel): +class PhiRotaryEmbedding(LlamaRotaryEmbedding): pass +class PhiPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not 
None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + class PhiModel(LlamaModel): def __init__(self, config: PhiConfig): super().__init__(config) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ac36fa5e21..1737d8c3df 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -373,6 +373,8 @@ class Phi3PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Phi3RMSNorm): + module.weight.data.fill_(1.0) class Phi3RotaryEmbedding(nn.Module): diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index d0a6b3816c..8677ef9dc9 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1030,6 +1030,9 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv): + module.b1.data.zero_() + module.b2.data.zero_() def unfold_tensor(tensor, max_seq_len): @@ -1607,6 +1610,40 @@ class Phi4MultimodalFeatureEmbedding(nn.Module): return inputs_embeds +class Phi4MultimodalRotaryEmbedding(nn.Module): + def __init__(self, config: Phi4MultimodalConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + 
self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + PHI4_MULTIMODAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1653,40 +1690,11 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - -class Phi4MultimodalRotaryEmbedding(nn.Module): - def __init__(self, config: Phi4MultimodalConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + elif isinstance(module, Phi4MultimodalRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Phi4MultimodalImageEmbedding): + module.global_img_feature_extensor.data.zero_() + module.sub_img_feature_extensor.data.zero_() PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 901cfa27b0..d269b06037 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -40,7 +40,14 @@ from ...utils import ( replace_return_docstrings, ) from ..phi3.configuration_phi3 import Phi3Config -from ..phi3.modeling_phi3 import Phi3DecoderLayer, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm +from ..phi3.modeling_phi3 import ( + Phi3DecoderLayer, + Phi3ForCausalLM, + Phi3Model, + Phi3PreTrainedModel, + Phi3RMSNorm, + Phi3RotaryEmbedding, +) from ..siglip.configuration_siglip import SiglipVisionConfig from ..siglip.modeling_siglip import ( SiglipEncoder, @@ -1133,6 +1140,9 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, 
Phi4MultimodalAudioGluPointWiseConv): + module.b1.data.zero_() + module.b2.data.zero_() class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): @@ -1519,6 +1529,28 @@ PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r""" """ +class Phi4MultimodalRotaryEmbedding(Phi3RotaryEmbedding): + pass + + +class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Phi4MultimodalRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Phi4MultimodalImageEmbedding): + module.global_img_feature_extensor.data.zero_() + module.sub_img_feature_extensor.data.zero_() + + class Phi4MultimodalModel(Phi3Model, nn.Module): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi4MultimodalMMDecoderLayer`] @@ -1829,7 +1861,7 @@ __all__ = [ "Phi4MultimodalAudioModel", "Phi4MultimodalVisionPreTrainedModel", "Phi4MultimodalVisionModel", - "Phi4MultimodalPreTrainedModel", # noqa + "Phi4MultimodalPreTrainedModel", "Phi4MultimodalModel", "Phi4MultimodalForCausalLM", "Phi4MultimodalVisionConfig", diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index a4dee0a6f7..ab8370d8ec 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -923,6 +923,9 @@ class PhimoePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) PHIMOE_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 1a3e3aee0f..bc0ec918e9 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -383,20 +383,13 @@ class PixtralPreTrainedModel(PreTrainedModel): _no_split_modules = ["PixtralAttentionLayer"] def _init_weights(self, module): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.initializer_range - ) - + std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, PixtralRMSNorm): + module.weight.data.fill_(1.0) PIXTRAL_INPUTS_DOCSTRING = r""" diff --git 
a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py index 7653a72f06..7a1e2c7e47 100644 --- a/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py @@ -254,15 +254,10 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) class PromptDepthAnythingReassembleLayer(nn.Module): diff --git a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py index ad9b254a8a..aa834339ea 100644 --- a/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +++ b/src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py @@ -210,15 +210,10 @@ class PromptDepthAnythingPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif 
isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) class PromptDepthAnythingReassembleLayer(nn.Module): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index ca85bacbb2..147f654652 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -331,6 +331,8 @@ class Qwen2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) class Qwen2RotaryEmbedding(nn.Module): diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py index 736736527f..d079846708 100644 --- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -92,6 +92,7 @@ class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig): window_size=112, out_hidden_size=3584, fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) @@ -108,6 +109,7 @@ class Qwen2_5OmniVisionEncoderConfig(PretrainedConfig): self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size + self.initializer_range = initializer_range class Qwen2_5OmniAudioEncoderConfig(PretrainedConfig): diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 464bf2bd8a..4c71c431ac 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -75,6 +75,26 @@ if is_flash_attn_available(): logger = logging.get_logger(__name__) +class Qwen2RMSNorm(nn.Module): + def 
__init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + Qwen2_5Omni_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -112,7 +132,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): # inference and fine-tuning - so the proper init weights code has been removed std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)): + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() @@ -120,6 +140,11 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): @@ -1102,26 +1127,6 @@ class Qwen2_5OmniMLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -class Qwen2RMSNorm(nn.Module): - def 
__init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { "eager": Qwen2_5OmniVisionAttention, "flash_attention_2": Qwen2_5OmniVisionFlashAttention2, diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 55630f70a3..9bb81ddcc5 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -36,6 +36,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLModel, Qwen2_5_VLPreTrainedModel, Qwen2_5_VLVisionBlock, + Qwen2RMSNorm, ) from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioEncoderConfig from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer @@ -130,6 +131,7 @@ class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig): window_size=112, out_hidden_size=3584, fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, **kwargs, ): super().__init__( @@ -145,6 +147,7 @@ class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig): window_size, out_hidden_size, fullatt_block_indexes, + initializer_range=initializer_range, **kwargs, ) del self.tokens_per_second @@ -1027,7 +1030,7 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): # inference and fine-tuning - so the proper init weights code has been removed std 
= self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d)): + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() @@ -1035,6 +1038,11 @@ class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index ed3505728a..63ca1c2359 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -46,6 +46,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): window_size=112, out_hidden_size=3584, fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) @@ -63,6 +64,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size + self.initializer_range = initializer_range class Qwen2_5_VLConfig(PretrainedConfig): diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 6e1bfb0e5f..5f0a9d003f 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -388,6 +388,8 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, 
std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 925d59df2f..fa245c45f5 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -89,6 +89,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): window_size=112, out_hidden_size=3584, fullatt_block_indexes=[7, 15, 23, 31], + initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) @@ -106,6 +107,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig): self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size + self.initializer_range = initializer_range class Qwen2_5_VLConfig(Qwen2VLConfig): @@ -224,7 +226,18 @@ class Qwen2_5_VLVisionBlock(nn.Module): class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): - pass + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 40be54033c..6c9cc40ad8 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -779,6 +779,8 @@ 
class Qwen2MoePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2MoeRMSNorm): + module.weight.data.fill_(1.0) QWEN2MOE_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 2917e2d8ba..b03dbc8f0b 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -38,6 +38,7 @@ class Qwen2VLVisionConfig(PretrainedConfig): patch_size=14, spatial_merge_size=2, temporal_patch_size=2, + initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) @@ -52,6 +53,7 @@ class Qwen2VLVisionConfig(PretrainedConfig): self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size + self.initializer_range = initializer_range class Qwen2VLConfig(PretrainedConfig): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index f92402c92e..e172f092d7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -914,6 +914,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): def _init_weights(self, module): std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: @@ -922,6 +923,11 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) class 
Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index f98a8a2759..89f30e78f4 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -358,6 +358,8 @@ class Qwen3PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3RMSNorm): + module.weight.data.fill_(1.0) class Qwen3RotaryEmbedding(nn.Module): diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 4196693d19..3462f565e2 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -488,6 +488,8 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3MoeRMSNorm): + module.weight.data.fill_(1.0) QWEN3_MOE_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 3bf5b43706..2e2afdbc01 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -581,6 +581,13 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): torch.nn.init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + elif isinstance(module, RecurrentGemmaRMSNorm): + 
module.weight.data.fill_(1.0) def _setup_cache(self, config, batch, device, dtype): layers = getattr(self, "model", self).layers diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 727925fb17..6ecf72d0a1 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -1040,8 +1040,6 @@ class RTDetrPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initalize the weights""" - - """initialize linear layer bias value according to a given probability value.""" if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)): if module.class_embed is not None: for layer in module.class_embed: @@ -1055,7 +1053,7 @@ class RTDetrPreTrainedModel(PreTrainedModel): nn.init.constant_(layer.layers[-1].weight, 0) nn.init.constant_(layer.layers[-1].bias, 0) - if isinstance(module, RTDetrMultiscaleDeformableAttention): + elif isinstance(module, RTDetrMultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( @@ -1078,17 +1076,21 @@ class RTDetrPreTrainedModel(PreTrainedModel): nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) - if isinstance(module, RTDetrModel): + elif isinstance(module, RTDetrModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) nn.init.xavier_uniform_(module.enc_score_head.weight) nn.init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() + elif 
isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: nn.init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index 59a00a6e74..9505c3fdd3 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -1314,8 +1314,6 @@ class RTDetrV2PreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initalize the weights""" - - """initialize linear layer bias value according to a given probability value.""" if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)): if module.class_embed is not None: for layer in module.class_embed: @@ -1329,7 +1327,7 @@ class RTDetrV2PreTrainedModel(PreTrainedModel): nn.init.constant_(layer.layers[-1].weight, 0) nn.init.constant_(layer.layers[-1].bias, 0) - if isinstance(module, RTDetrV2MultiscaleDeformableAttention): + elif isinstance(module, RTDetrV2MultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( @@ -1352,17 +1350,21 @@ class RTDetrV2PreTrainedModel(PreTrainedModel): nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) - if isinstance(module, RTDetrV2Model): + elif isinstance(module, RTDetrV2Model): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) nn.init.xavier_uniform_(module.enc_score_head.weight) nn.init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, 
nn.BatchNorm2d)): + elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: nn.init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index e88e23776b..6c33479e91 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -80,14 +80,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _supports_cache_class = True def _init_weights(self, module): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.get_text_config().initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): module.weight.data.normal_(mean=0.0, std=std) @@ -97,6 +90,9 @@ class SmolVLMPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() class SmolVLMVisionEmbeddings(nn.Module): diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py index 051bdfaf5a..4745fe30da 100644 --- a/src/transformers/models/smolvlm/modular_smolvlm.py +++ b/src/transformers/models/smolvlm/modular_smolvlm.py @@ -94,7 +94,20 @@ class 
SmolVLMVisionConfig(Idefics3VisionConfig): class SmolVLMPreTrainedModel(Idefics3PreTrainedModel): - pass + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() class SmolVLMVisionTransformer(Idefics3VisionTransformer): diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 69d2ff3c0d..69beb543b4 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -666,6 +666,9 @@ class StableLmPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() STABLELM_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 64ffd865a5..569874aad1 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -275,6 +275,40 @@ class Starcoder2DecoderLayer(nn.Module): return outputs +class Starcoder2RotaryEmbedding(nn.Module): + def __init__(self, config: Starcoder2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = 
config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + STARCODER2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -320,40 +354,9 @@ class Starcoder2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - -class Starcoder2RotaryEmbedding(nn.Module): - def __init__(self, config: Starcoder2Config, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() STARCODER2_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 1aaf789a50..77b05d1dc1 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -41,6 +41,8 @@ from ..mistral.modeling_mistral import ( MistralForSequenceClassification, MistralForTokenClassification, MistralModel, + MistralPreTrainedModel, + MistralRotaryEmbedding, apply_rotary_pos_emb, eager_attention_forward, ) @@ -143,6 +145,26 @@ class Starcoder2DecoderLayer(MistralDecoderLayer): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) +class Starcoder2RotaryEmbedding(MistralRotaryEmbedding): + pass + + +class Starcoder2PreTrainedModel(MistralPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + 
module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 20b7c2a20e..0196169c37 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -166,15 +166,6 @@ class UperNetHead(nn.Module): padding=1, ) - def init_weights(self): - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Conv2d): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - def psp_forward(self, inputs): x = inputs[-1] psp_outs = [x] @@ -266,15 +257,6 @@ class UperNetFCNHead(nn.Module): self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) - def init_weights(self): - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Conv2d): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: # just take the relevant feature maps hidden_states = encoder_hidden_states[self.in_index] @@ -296,18 +278,13 @@ class UperNetPreTrainedModel(PreTrainedModel): _no_split_modules = [] def _init_weights(self, module): - if isinstance(module, UperNetPreTrainedModel): - module.backbone.init_weights() - module.decode_head.init_weights() - if module.auxiliary_head is not None: - module.auxiliary_head.init_weights() - - def init_weights(self): - """Initialize the weights""" - self.backbone.init_weights() - self.decode_head.init_weights() - if self.auxiliary_head is not None: - self.auxiliary_head.init_weights() + if isinstance(module, nn.Conv2d): + 
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.BatchNorm2d): + module.weight.data.fill_(1.0) + module.bias.data.zero_() UPERNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 3c706e43d5..71e5b9498b 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -128,7 +128,6 @@ VIPLLAVA_START_DOCSTRING = r""" "The bare VipLlava Model outputting raw hidden-states without any specific head on top.", VIPLLAVA_START_DOCSTRING, ) -# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->VipLlava,llava->vipllava class VipLlavaPreTrainedModel(PreTrainedModel): config_class = VipLlavaConfig base_model_prefix = "model" @@ -142,26 +141,15 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_static_cache = True def _init_weights(self, module): - # important: this ported version of VipLlava isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - the original codebase - # https://github.com/haotian-liu/LLaVA/tree/main/vipllava should serve for that purpose - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): + if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not 
None: - module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) VIPLLAVA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index dec1e03937..2cb995586a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -786,10 +786,14 @@ class WhisperPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() elif isinstance(module, WhisperEncoder): - with torch.no_grad(): - embed_positions = module.embed_positions.weight - embed_positions.copy_(sinusoids(*embed_positions.shape)) + module.embed_positions.weight.copy_(sinusoids(*module.embed_positions.weight.shape)) + elif isinstance(module, WhisperForAudioClassification): + if self.config.use_weighted_layer_sum: + module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 27d75129e8..2fe1720c42 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -850,10 +850,9 @@ class ZambaPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, ZambaRMSNorm): + module.weight.data.fill_(1.0) elif isinstance(module, ZambaMambaMixer): - module.A_log._no_weight_decay = True - module.D._no_weight_decay = True - module.x_proj_weight.data.normal_(mean=0.0, std=std) dt_init_std = 
self.config.mamba_dt_rank**-0.5 nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std) @@ -866,10 +865,12 @@ class ZambaPreTrainedModel(PreTrainedModel): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) + module.dt_proj_bias.data.copy_(inv_dt) - with torch.no_grad(): - module.dt_proj_bias.copy_(inv_dt) - module.dt_proj_bias._no_reinit = True + A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(module.intermediate_size, -1).contiguous() + module.A_log.data.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1)) + module.D.data.fill_(1.0) @classmethod @classmethod diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 843420d7db..a3735303ec 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1225,10 +1225,9 @@ class Zamba2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (Zamba2RMSNorm, Zamba2RMSNormGated)): + module.weight.data.fill_(1.0) elif isinstance(module, Zamba2MambaMixer): - module.A_log._no_weight_decay = True - module.D._no_weight_decay = True - dt = torch.exp( torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) @@ -1236,10 +1235,11 @@ class Zamba2PreTrainedModel(PreTrainedModel): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) + module.dt_bias.data.copy_(inv_dt) - with torch.no_grad(): - module.dt_bias.copy_(inv_dt) - module.dt_bias._no_reinit = True + A = torch.arange(1, module.num_heads + 1) + module.A_log.data.copy_(torch.log(A)) + 
module.D.data.fill_(1.0) ZAMBA2_START_DOCSTRING = r""" diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index ece5fb1065..2c672bba5a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -930,10 +930,9 @@ class Zamba2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (Zamba2RMSNorm, Zamba2RMSNormGated)): + module.weight.data.fill_(1.0) elif isinstance(module, Zamba2MambaMixer): - module.A_log._no_weight_decay = True - module.D._no_weight_decay = True - dt = torch.exp( torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) @@ -941,10 +940,11 @@ class Zamba2PreTrainedModel(PreTrainedModel): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) + module.dt_bias.data.copy_(inv_dt) - with torch.no_grad(): - module.dt_bias.copy_(inv_dt) - module.dt_bias._no_reinit = True + A = torch.arange(1, module.num_heads + 1) + module.A_log.data.copy_(torch.log(A)) + module.D.data.fill_(1.0) class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): diff --git a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py index e7f2ad6a37..ee1c4c5b71 100644 --- a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py +++ b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py @@ -85,6 +85,7 @@ class Phi4MultimodalModelTester: intermediate_size=48, depthwise_seperable_out_channel=128, nemo_conv_channels=128, + initializer_range=1e-5, ), vision_config=Phi4MultimodalVisionConfig( num_hidden_layers=2, @@ -92,6 +93,7 @@ class Phi4MultimodalModelTester: intermediate_size=64, 
num_attention_heads=8, crop_size=16, + initializer_range=1e-5, ), ): self.parent = parent diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 58d1e46220..7aa8134e5d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -503,6 +503,76 @@ class ModelTesterMixin: m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" ) + def test_can_init_all_missing_weights(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # This is used to get the addition year of the model + filename = inspect.getfile(config.__class__) + # No easy way to get model addition date -> check copyright year on top of file + with open(filename) as file: + source_code = file.read() + addition_year = 0 # if we cannot find it, set it to 0 (i.e. oldest) + if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE): + addition_year = int(match_object.group(1)) + + for model_class in self.all_model_classes: + # For now, skip everything older than 2025 and "important models" (too much models to patch otherwise) + # Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them + # TODO: relax this as we patch more and more models + if addition_year < 2025 and not model_class._supports_cache_class: + self.skipTest(reason=f"{model_class} is not a priorited model for now.") + + # Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps + # `_init_weights` so that it can add the seed for composite models as well) + original_initialize_weights = PreTrainedModel._initialize_weights + + def seeded_initialize_weights(self, module): + set_seed(0) + original_initialize_weights(self, module) + + PreTrainedModel._initialize_weights = seeded_initialize_weights + + # First, initialize the model from config -> this ensure everything is correctly initialized, even if + # _init_weights() does not 
take all weights into account correctly + model_from_config = model_class(config) + # Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be initialized + # by _init_weights() + model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={}) + + # Back to original method to avoid issues if running several other tests + PreTrainedModel._initialize_weights = original_initialize_weights + + # First, check if any parameters are still on meta -> this is usually an issue with tied weights + params_on_meta = [] + for k, v in model_from_pretrained.named_parameters(): + if v.device.type == "meta": + params_on_meta.append(k) + + self.assertTrue( + len(params_on_meta) == 0, + f"The following keys are still on the meta device, it probably comes from an issue in the tied weights:\n{params_on_meta}", + ) + + # Everything must be exactly the same as we set the same seed for each init + different_weights = [] + for (k1, v1), (k2, v2) in zip( + model_from_config.state_dict().items(), model_from_pretrained.state_dict().items() + ): + self.assertEqual(k1, k2, "The keys from each model should be the same") + # Since we added the seed, they should be exactly the same (i.e. using allclose maybe be wrong due + # to very low std in init function) + if not (v1 == v2).all(): + different_weights.append(k1) + + # Buffers that are initialized randomly are ignored as they are not initialized on meta device anyway + buffer_names = {name for name, _ in model_from_config.named_buffers()} + different_weights = [k for k in different_weights if k not in buffer_names] + + self.assertTrue( + len(different_weights) == 0, + f"The following keys are not properly handled by `_init_weights()`:\n{different_weights}", + ) + @slow @require_accelerate @mark.accelerate_tests