Detect and fix most _init_weights() issues - make it work for composite models (#37070)

* Update test_modeling_common.py

* Fix Llama and its modular children

* Update test_modeling_common.py

* qwen3

* first try at prioritizing models

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* test

* fix

* fix

* more models

* more

* more

* more

* smarter init for composite models!

* fix post rebase

* smol

* fix missing args

* more

* typo

* Super elegant and efficient init for submodels

* Update modeling_utils.py

* style

* last fixes

* cleanup

* finalize cleanup

* CIs

* improve docstring

* Update modeling_utils.py

* llama4

* style

* CIs

* style

* add dpt

* granite speech

* qwen 2.5 omni

* better fix

* Parse the config file instead

* CIs
This commit is contained in:
Cyril Vallez
2025-04-14 16:19:04 +02:00
committed by GitHub
parent 1897a02d83
commit 4e53840920
103 changed files with 1164 additions and 795 deletions

View File

@@ -2449,6 +2449,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
self._init_weights(module)
module._is_hf_initialized = True
@torch.no_grad()
def initialize_weights(self):
    """
    Initialize the weights of this model, dispatching to each sub-model's own init function.

    This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
    This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
    module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
    model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
    is extremely error prone and inefficient.

    Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
    `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
    `module.weight.data.zero_()`.

    As a side effect, the first call installs a `smart_apply` method on `torch.nn.Module` (only if no attribute of
    that name exists yet), so that the traversal can recurse through arbitrary child modules.
    """
    if not hasattr(torch.nn.Module, "smart_apply"):
        # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function
        # to apply as we go down the graph
        def smart_apply(self, fn):
            for module in self.children():
                # We found a sub-model: recursively dispatch its own init function now!
                # Both attributes are required: `_initialize_weights` is the method actually called here, so a
                # foreign module that only happens to define `_init_weights` must not be treated as a sub-model
                # (it would raise an AttributeError). Such modules keep receiving the parent's function instead.
                if hasattr(module, "_init_weights") and hasattr(module, "_initialize_weights"):
                    module.smart_apply(module._initialize_weights)
                else:
                    module.smart_apply(fn)
            # Children first, then the module itself (same post-order as `torch.nn.Module.apply`)
            fn(self)
            return self

        torch.nn.Module.smart_apply = smart_apply

    # Let the magic happen with this simple call
    self.smart_apply(self._initialize_weights)
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
@@ -3074,7 +3105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if _init_weights:
# Initialize weights
self.apply(self._initialize_weights)
self.initialize_weights()
# Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways
@@ -5286,9 +5317,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
self.apply(self._initialize_weights)
self.initialize_weights()
else:
self.apply(self._initialize_weights)
self.initialize_weights()
def get_parameter_or_buffer(self, target: str):
"""