Detect and fix most _init_weights() issues - make it work for composite models (#37070)

* Update test_modeling_common.py

* Fix Llama and its modular children

* Update test_modeling_common.py

* qwen3

* first try at prioritizing models

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* test

* fix

* fix

* more models

* more

* more

* more

* smarter init for composite models!

* fix post rebase

* smol

* fix missing args

* more

* typo

* Super elegant and efficient init for submodels

* Update modeling_utils.py

* style

* last fixes

* cleanup

* finalize cleanup

* CIs

* improve docstring

* Update modeling_utils.py

* llama4

* style

* CIs

* style

* add dpt

* granite speech

* qwen 2.5 omni

* better fix

* Parse the config file instead

* CIs
This commit is contained in:
Cyril Vallez
2025-04-14 16:19:04 +02:00
committed by GitHub
parent 1897a02d83
commit 4e53840920
103 changed files with 1164 additions and 795 deletions

View File

@@ -679,12 +679,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
ARIA_TEXT_START_DOCSTRING = r"""
@@ -724,14 +722,17 @@ class AriaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, AriaProjector):
nn.init.trunc_normal_(module.query, std=std)

View File

@@ -1255,12 +1255,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
class AriaPreTrainedModel(LlamaPreTrainedModel):
@@ -1269,14 +1267,17 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, AriaProjector):
nn.init.trunc_normal_(module.query, std=std)