Fix missing initializations for models created in 2024 (#38987)
* fix GroundingDino
* fix SuperGlue
* fix GroundingDino
* fix MambaModel
* fix OmDetTurbo
* fix SegGpt
* fix Qwen2Audio
* fix Mamba2
* fix DabDetr
* fix Dac
* fix FalconMamba
* skip timm initialization
* fix Encodec and MusicgenMelody
* fix Musicgen
* skip timm initialization test
* fix OmDetTurbo
* clean the code

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* add reviewed changes
* add back timm
* style
* better check for parametrizations

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -310,12 +310,13 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# original_config.norm_type = "time_group_norm"
|
||||
for model_class in self.all_model_classes:
|
||||
torch.manual_seed(0)
|
||||
config = copy.deepcopy(original_config)
|
||||
config.chunk_length_s = None
|
||||
config.overlap = None
|
||||
config.sampling_rate = 10
|
||||
config.sampling_rate = 20
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@@ -326,9 +327,9 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
hidden_states_no_chunk = model(**inputs)[1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
config.chunk_length_s = 1
|
||||
config.chunk_length_s = 2
|
||||
config.overlap = 0
|
||||
config.sampling_rate = 10
|
||||
config.sampling_rate = 20
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
@@ -33,7 +33,7 @@ from transformers.testing_utils import (
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -359,9 +359,11 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
|
||||
def test_initialization(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.rescale_prenorm_residual = True
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if "dt_proj.bias" in name:
|
||||
dt = torch.exp(
|
||||
@@ -380,6 +382,19 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
if param.requires_grad:
|
||||
# check if it's a ones like
|
||||
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
|
||||
else:
|
||||
if param.requires_grad:
|
||||
if (
|
||||
"mixer.conv1d.weight" in name
|
||||
or "mixer.dt_proj.weight" in name
|
||||
or "mixer.out_proj.weight" in name
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@slow
|
||||
# Ignore copy
|
||||
|
||||
@@ -586,6 +586,8 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
or "value_proj" in name
|
||||
or "output_proj" in name
|
||||
or "reference_points" in name
|
||||
or "vision_proj" in name
|
||||
or "text_proj" in name
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers.testing_utils import require_torch, require_torch_multi_gpu, s
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -326,9 +326,11 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def test_initialization(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.rescale_prenorm_residual = True
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if "dt_proj.bias" in name:
|
||||
dt = torch.exp(
|
||||
@@ -347,6 +349,19 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
if param.requires_grad:
|
||||
# check if it's a ones like
|
||||
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
|
||||
else:
|
||||
if param.requires_grad:
|
||||
if (
|
||||
"mixer.conv1d.weight" in name
|
||||
or "mixer.dt_proj.weight" in name
|
||||
or "mixer.out_proj.weight" in name
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||
@@ -28,7 +29,7 @@ from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -276,14 +277,37 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
|
||||
def test_initialization(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.rescale_prenorm_residual = True
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if "D" in name:
|
||||
if "dt_proj.bias" in name:
|
||||
dt = torch.exp(
|
||||
torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min))
|
||||
+ math.log(config.time_step_min)
|
||||
).clamp(min=config.time_step_floor)
|
||||
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||
if param.requires_grad:
|
||||
self.assertTrue(param.data.max().item() <= inv_dt[1])
|
||||
self.assertTrue(param.data.min().item() >= inv_dt[0])
|
||||
elif "A_log" in name:
|
||||
A = torch.arange(1, config.num_heads + 1)
|
||||
torch.testing.assert_close(param.data, torch.log(A), rtol=1e-5, atol=1e-5)
|
||||
elif "D" in name:
|
||||
if param.requires_grad:
|
||||
# check if it's a ones like
|
||||
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
|
||||
else:
|
||||
if param.requires_grad:
|
||||
if "mixer.conv1d.weight" in name or "mixer.dt_bias" in name or "mixer.out_proj.weight" in name:
|
||||
continue
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Mamba 2 weights are not tied")
|
||||
def test_tied_weights_keys(self):
|
||||
|
||||
@@ -629,6 +629,7 @@ class OmDetTurboModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
or "decoder.channel_projection_layers" in name
|
||||
or "query_position_head" in name
|
||||
or "decoder.encoder_vision_features" in name
|
||||
or "language_backbone.text_projection" in name
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
|
||||
@@ -153,10 +153,18 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
|
||||
def test_can_init_all_missing_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
|
||||
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Need to use a timm model and there is no tiny model available.")
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@@ -855,7 +855,7 @@ class ModelTesterMixin:
|
||||
# For now, skip everything older than 2025 and "important models" (too many models to patch otherwise)
|
||||
# Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
|
||||
# TODO: relax this as we patch more and more models
|
||||
if addition_year < 2025 and not model_class._supports_cache_class:
|
||||
if addition_year < 2024 and not model_class._supports_cache_class:
|
||||
self.skipTest(reason=f"{model_class} is not a priorited model for now.")
|
||||
|
||||
# Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
|
||||
@@ -895,6 +895,11 @@ class ModelTesterMixin:
|
||||
model_from_config.state_dict().items(), model_from_pretrained.state_dict().items()
|
||||
):
|
||||
self.assertEqual(k1, k2, "The keys from each model should be the same")
|
||||
|
||||
# When torch.nn.utils.parametrizations is used on a module, we should skip the resulting keys
|
||||
if re.search(r"\.parametrizations\..*?\.original[01]", k1):
|
||||
continue
|
||||
|
||||
# Since we added the seed, they should be exactly the same (i.e. using allclose may be wrong due
|
||||
# to very low std in init function)
|
||||
if not (v1 == v2).all():
|
||||
|
||||
Reference in New Issue
Block a user