Detect and fix most _init_weights() issues - make it work for composite models (#37070)

* Update test_modeling_common.py

* Fix Llama and its modular children

* Update test_modeling_common.py

* qwen3

* first try at prioritizing models

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* test

* fix

* fix

* more models

* more

* more

* more

* smarter init for composite models!

* fix post rebase

* smol

* fix missing args

* more

* typo

* Super elegant and efficient init for submodels

* Update modeling_utils.py

* style

* last fixes

* cleanup

* finalize cleanup

* CIs

* improve docstring

* Update modeling_utils.py

* llama4

* style

* CIs

* style

* add dpt

* granite speech

* qwen 2.5 omni

* better fix

* Parse the config file instead

* CIs
This commit is contained in:
Cyril Vallez
2025-04-14 16:19:04 +02:00
committed by GitHub
parent 1897a02d83
commit 4e53840920
103 changed files with 1164 additions and 795 deletions

View File

@@ -85,6 +85,7 @@ class Phi4MultimodalModelTester:
intermediate_size=48,
depthwise_seperable_out_channel=128,
nemo_conv_channels=128,
initializer_range=1e-5,
),
vision_config=Phi4MultimodalVisionConfig(
num_hidden_layers=2,
@@ -92,6 +93,7 @@ class Phi4MultimodalModelTester:
intermediate_size=64,
num_attention_heads=8,
crop_size=16,
initializer_range=1e-5,
),
):
self.parent = parent

View File

@@ -503,6 +503,76 @@ class ModelTesterMixin:
m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
)
def test_can_init_all_missing_weights(self):
    """Check that `_init_weights()` correctly initializes every missing weight.

    For each model class, a model built directly from the config is compared against a model built with
    `from_pretrained(None, config=config, state_dict={})`. Both initializations are seeded identically, so
    any non-buffer weight that differs between the two is reported as not handled by `_init_weights()`.
    Parameters left on the meta device are reported as well.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # No easy way to get the model addition date -> check the copyright year on top of the config file
    filename = inspect.getfile(config.__class__)
    # Source files are UTF-8: do not rely on the platform default encoding
    with open(filename, encoding="utf-8") as file:
        source_code = file.read()
    addition_year = 0  # if we cannot find it, set it to 0 (i.e. oldest)
    if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE):
        addition_year = int(match_object.group(1))

    for model_class in self.all_model_classes:
        # For now, skip everything older than 2025 unless it is an "important" model (too many models to
        # patch otherwise). `_supports_cache_class` is used as a proxy to judge "important" models in order
        # to prioritize them.
        # NOTE: `skipTest` raises, so the whole test is skipped at the first non-prioritized class.
        # TODO: relax this as we patch more and more models
        if addition_year < 2025 and not model_class._supports_cache_class:
            self.skipTest(reason=f"{model_class} is not a priorited model for now.")

        # Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
        # `_init_weights` so that it can add the seed for composite models as well)
        original_initialize_weights = PreTrainedModel._initialize_weights

        def seeded_initialize_weights(model, module):
            set_seed(0)
            original_initialize_weights(model, module)

        PreTrainedModel._initialize_weights = seeded_initialize_weights
        try:
            # First, initialize the model from config -> this ensures everything is correctly initialized,
            # even if _init_weights() does not take all weights into account correctly
            model_from_config = model_class(config)
            # Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be
            # initialized by _init_weights()
            model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={})
        finally:
            # Always restore the original method, even if model creation raised, so that the seeded
            # version never leaks into other tests
            PreTrainedModel._initialize_weights = original_initialize_weights

        # First, check if any parameters are still on meta -> this is usually an issue with tied weights
        params_on_meta = [k for k, v in model_from_pretrained.named_parameters() if v.device.type == "meta"]
        self.assertTrue(
            len(params_on_meta) == 0,
            f"The following keys are still on the meta device, it probably comes from an issue in the tied weights:\n{params_on_meta}",
        )

        state_dict_from_config = model_from_config.state_dict()
        state_dict_from_pretrained = model_from_pretrained.state_dict()
        # Compare the full key lists up front: `zip` below would silently truncate to the shortest state
        # dict and hide extra or missing keys
        self.assertEqual(
            list(state_dict_from_config),
            list(state_dict_from_pretrained),
            "The keys from each model should be the same",
        )

        # Everything must be exactly the same as we set the same seed for each init. Using allclose may be
        # wrong here due to the very low std in some init functions, hence the exact comparison.
        different_weights = [
            k1
            for (k1, v1), v2 in zip(state_dict_from_config.items(), state_dict_from_pretrained.values())
            if not (v1 == v2).all()
        ]

        # Buffers that are initialized randomly are ignored as they are not initialized on meta device anyway
        buffer_names = {name for name, _ in model_from_config.named_buffers()}
        different_weights = [k for k in different_weights if k not in buffer_names]

        self.assertTrue(
            len(different_weights) == 0,
            f"The following keys are not properly handled by `_init_weights()`:\n{different_weights}",
        )
@slow
@require_accelerate
@mark.accelerate_tests