Fix missing initializations for models created in 2024 (#38987)
* fix GroundingDino * fix SuperGlue * fix GroundingDino * fix MambaModel * fix OmDetTurbo * fix SegGpt * fix Qwen2Audio * fix Mamba2 * fix DabDetr * fix Dac * fix FalconMamba * skip timm initialization * fix Encodec and MusicgenMelody * fix Musicgen * skip timm initialization test * fix OmDetTurbo * clean the code Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * add reviewed changes * add back timm * style * better check for parametrizations --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -24,7 +24,7 @@ from transformers.testing_utils import require_torch, require_torch_multi_gpu, s
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -326,9 +326,11 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def test_initialization(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.rescale_prenorm_residual = True
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if "dt_proj.bias" in name:
|
||||
dt = torch.exp(
|
||||
@@ -347,6 +349,19 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
if param.requires_grad:
|
||||
# check if it's a ones like
|
||||
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
|
||||
else:
|
||||
if param.requires_grad:
|
||||
if (
|
||||
"mixer.conv1d.weight" in name
|
||||
or "mixer.dt_proj.weight" in name
|
||||
or "mixer.out_proj.weight" in name
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
Reference in New Issue
Block a user