Detect and fix most _init_weights() issues - make it work for composite models (#37070)
* Update test_modeling_common.py * Fix Llama and its modular children * Update test_modeling_common.py * qwen3 * first try at prioritizing models * Update test_modeling_common.py * Update test_modeling_common.py * Update test_modeling_common.py * test * fix * fix * more models * more * more * more * smarter init for composite models! * fix post rebase * smol * fix missing args * more * typo * Super elegant and efficient init for submodels * Update modeling_utils.py * style * last fixes * cleanup * finalize cleanup * CIs * improve docstring * Update modeling_utils.py * llama4 * style * CIs * style * add dpt * granite speech * qwen 2.5 omni * better fix * Parse the config file instead * CIs
This commit is contained in:
@@ -85,6 +85,7 @@ class Phi4MultimodalModelTester:
|
||||
intermediate_size=48,
|
||||
depthwise_seperable_out_channel=128,
|
||||
nemo_conv_channels=128,
|
||||
initializer_range=1e-5,
|
||||
),
|
||||
vision_config=Phi4MultimodalVisionConfig(
|
||||
num_hidden_layers=2,
|
||||
@@ -92,6 +93,7 @@ class Phi4MultimodalModelTester:
|
||||
intermediate_size=64,
|
||||
num_attention_heads=8,
|
||||
crop_size=16,
|
||||
initializer_range=1e-5,
|
||||
),
|
||||
):
|
||||
self.parent = parent
|
||||
|
||||
@@ -503,6 +503,76 @@ class ModelTesterMixin:
|
||||
m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
|
||||
)
|
||||
|
||||
def test_can_init_all_missing_weights(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# This is used to get the addition year of the model
|
||||
filename = inspect.getfile(config.__class__)
|
||||
# No easy way to get model addition date -> check copyright year on top of file
|
||||
with open(filename) as file:
|
||||
source_code = file.read()
|
||||
addition_year = 0 # if we cannot find it, set it to 0 (i.e. oldest)
|
||||
if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE):
|
||||
addition_year = int(match_object.group(1))
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# For now, skip everything older than 2025 and "important models" (too much models to patch otherwise)
|
||||
# Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
|
||||
# TODO: relax this as we patch more and more models
|
||||
if addition_year < 2025 and not model_class._supports_cache_class:
|
||||
self.skipTest(reason=f"{model_class} is not a priorited model for now.")
|
||||
|
||||
# Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
|
||||
# `_init_weights` so that it can add the seed for composite models as well)
|
||||
original_initialize_weights = PreTrainedModel._initialize_weights
|
||||
|
||||
def seeded_initialize_weights(self, module):
|
||||
set_seed(0)
|
||||
original_initialize_weights(self, module)
|
||||
|
||||
PreTrainedModel._initialize_weights = seeded_initialize_weights
|
||||
|
||||
# First, initialize the model from config -> this ensure everything is correctly initialized, even if
|
||||
# _init_weights() does not take all weights into account correctly
|
||||
model_from_config = model_class(config)
|
||||
# Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be initialized
|
||||
# by _init_weights()
|
||||
model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={})
|
||||
|
||||
# Back to original method to avoid issues if running several other tests
|
||||
PreTrainedModel._initialize_weights = original_initialize_weights
|
||||
|
||||
# First, check if any parameters are still on meta -> this is usually an issue with tied weights
|
||||
params_on_meta = []
|
||||
for k, v in model_from_pretrained.named_parameters():
|
||||
if v.device.type == "meta":
|
||||
params_on_meta.append(k)
|
||||
|
||||
self.assertTrue(
|
||||
len(params_on_meta) == 0,
|
||||
f"The following keys are still on the meta device, it probably comes from an issue in the tied weights:\n{params_on_meta}",
|
||||
)
|
||||
|
||||
# Everything must be exactly the same as we set the same seed for each init
|
||||
different_weights = []
|
||||
for (k1, v1), (k2, v2) in zip(
|
||||
model_from_config.state_dict().items(), model_from_pretrained.state_dict().items()
|
||||
):
|
||||
self.assertEqual(k1, k2, "The keys from each model should be the same")
|
||||
# Since we added the seed, they should be exactly the same (i.e. using allclose maybe be wrong due
|
||||
# to very low std in init function)
|
||||
if not (v1 == v2).all():
|
||||
different_weights.append(k1)
|
||||
|
||||
# Buffers that are initialized randomly are ignored as they are not initialized on meta device anyway
|
||||
buffer_names = {name for name, _ in model_from_config.named_buffers()}
|
||||
different_weights = [k for k in different_weights if k not in buffer_names]
|
||||
|
||||
self.assertTrue(
|
||||
len(different_weights) == 0,
|
||||
f"The following keys are not properly handled by `_init_weights()`:\n{different_weights}",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
|
||||
Reference in New Issue
Block a user