Detect and fix most _init_weights() issues - make it work for composite models (#37070)
* Update test_modeling_common.py
* Fix Llama and its modular children
* Update test_modeling_common.py
* qwen3
* first try at prioritizing models
* Update test_modeling_common.py
* Update test_modeling_common.py
* Update test_modeling_common.py
* test
* fix
* fix
* more models
* more
* more
* more
* smarter init for composite models!
* fix post rebase
* smol
* fix missing args
* more
* typo
* Super elegant and efficient init for submodels
* Update modeling_utils.py
* style
* last fixes
* cleanup
* finalize cleanup
* CIs
* improve docstring
* Update modeling_utils.py
* llama4
* style
* CIs
* style
* add dpt
* granite speech
* qwen 2.5 omni
* better fix
* Parse the config file instead
* CIs
This commit is contained in:
@@ -679,12 +679,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, AriaTextRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
ARIA_TEXT_START_DOCSTRING = r"""
|
||||
@@ -724,14 +722,17 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# This uses torch's original init
|
||||
module._reset_parameters()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, AriaProjector):
|
||||
nn.init.trunc_normal_(module.query, std=std)
|
||||
|
||||
|
||||
@@ -1255,12 +1255,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, AriaTextRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
@@ -1269,14 +1267,17 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# This uses torch's original init
|
||||
module._reset_parameters()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, AriaProjector):
|
||||
nn.init.trunc_normal_(module.query, std=std)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user