Detect and fix most _init_weights() issues - make it work for composite models (#37070)
* Update test_modeling_common.py * Fix Llama and its modular children * Update test_modeling_common.py * qwen3 * first try at prioritizing models * Update test_modeling_common.py * Update test_modeling_common.py * Update test_modeling_common.py * test * fix * fix * more models * more * more * more * smarter init for composite models! * fix post rebase * smol * fix missing args * more * typo * Super elegant and efficient init for submodels * Update modeling_utils.py * style * last fixes * cleanup * finalize cleanup * CIs * improve docstring * Update modeling_utils.py * llama4 * style * CIs * style * add dpt * granite speech * qwen 2.5 omni * better fix * Parse the config file instead * CIs
This commit is contained in:
@@ -2449,6 +2449,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
self._init_weights(module)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
@torch.no_grad()
|
||||
def initialize_weights(self):
|
||||
"""
|
||||
This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
|
||||
This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
|
||||
module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
|
||||
model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
|
||||
is extremely error prone and inefficient.
|
||||
|
||||
Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
|
||||
`torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
|
||||
`module.weight.data.zero_()`.
|
||||
"""
|
||||
if not hasattr(torch.nn.Module, "smart_apply"):
|
||||
# This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function
|
||||
# to apply as we go down the graph
|
||||
def smart_apply(self, fn):
|
||||
for module in self.children():
|
||||
# We found a sub-model: recursively dispatch its own init function now!
|
||||
if hasattr(module, "_init_weights"):
|
||||
module.smart_apply(module._initialize_weights)
|
||||
else:
|
||||
module.smart_apply(fn)
|
||||
fn(self)
|
||||
return self
|
||||
|
||||
torch.nn.Module.smart_apply = smart_apply
|
||||
|
||||
# Let the magic happen with this simple call
|
||||
self.smart_apply(self._initialize_weights)
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings.
|
||||
@@ -3074,7 +3105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
if _init_weights:
|
||||
# Initialize weights
|
||||
self.apply(self._initialize_weights)
|
||||
self.initialize_weights()
|
||||
|
||||
# Tie weights should be skipped when not initializing all weights
|
||||
# since from_pretrained(...) calls tie weights anyways
|
||||
@@ -5286,9 +5317,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
)
|
||||
)
|
||||
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
|
||||
self.apply(self._initialize_weights)
|
||||
self.initialize_weights()
|
||||
else:
|
||||
self.apply(self._initialize_weights)
|
||||
self.initialize_weights()
|
||||
|
||||
def get_parameter_or_buffer(self, target: str):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user