General weight initialization scheme (#39579)
* general + modulars from llama * all modular models * style and fix musicgen * fix * Update configuration_musicgen.py * Update modeling_utils.py
This commit is contained in:
@@ -2967,12 +2967,41 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""
|
||||
Initialize the weights. This method should be overridden by derived class and is
|
||||
the only initialization method that will be called when loading a checkpoint
|
||||
using `from_pretrained`. Any attempt to initialize outside of this function
|
||||
will be useless as the torch.nn.init function are all replaced with skip.
|
||||
Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
|
||||
initialization scheme, it should be overriden by the derived `PreTrainedModel` class. In case a model adds an explicit
|
||||
`nn.Parameter`, this method should also be overriden in order to initialize it correctly.
|
||||
"""
|
||||
pass
|
||||
if hasattr(self.config, "initializer_range"):
|
||||
std = self.config.initializer_range
|
||||
else:
|
||||
# 0.02 is the standard default value accross the library
|
||||
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.MultiheadAttention):
|
||||
# This uses torch's original init
|
||||
module._reset_parameters()
|
||||
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
|
||||
# between modelings (because they are prefixed with the model name)
|
||||
elif (
|
||||
isinstance(
|
||||
module, (nn.LayerNorm, nn.RMSNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
|
||||
)
|
||||
or "LayerNorm" in module.__class__.__name__
|
||||
or "RMSNorm" in module.__class__.__name__
|
||||
):
|
||||
# Norms can exist without weights (in which case they are None from torch primitives)
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(1.0)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _initialize_weights(self, module):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user