General weight initialization scheme (#39579)

* general + modulars from llama

* all modular models

* style and fix musicgen

* fix

* Update configuration_musicgen.py

* Update modeling_utils.py
This commit is contained in:
Cyril Vallez
2025-07-22 16:04:20 +02:00
committed by GitHub
parent 015b62bf3e
commit b16688e96a
118 changed files with 205 additions and 1566 deletions

View File

@@ -2967,12 +2967,41 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
def _init_weights(self, module):
"""
Initialize the weights. This method should be overridden by derived class and is
the only initialization method that will be called when loading a checkpoint
using `from_pretrained`. Any attempt to initialize outside of this function
will be useless as the torch.nn.init function are all replaced with skip.
Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
initialization scheme, it should be overriden by the derived `PreTrainedModel` class. In case a model adds an explicit
`nn.Parameter`, this method should also be overriden in order to initialize it correctly.
"""
pass
if hasattr(self.config, "initializer_range"):
std = self.config.initializer_range
else:
# 0.02 is the standard default value accross the library
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(
module, (nn.LayerNorm, nn.RMSNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
)
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
def _initialize_weights(self, module):
"""