From 6a61e16626d26ea209e94a15b9ab3bfff7ba9bc5 Mon Sep 17 00:00:00 2001 From: BUI Van Tuan <37981884+bvantuan@users.noreply.github.com> Date: Mon, 28 Jul 2025 10:47:39 +0200 Subject: [PATCH] Fix missing initialization of `FastSpeech2Conformer` (#39689) * fix missing initialization of FastSpeech2Conformer * switch order and reactivate tests --------- Co-authored-by: Cyril Vallez --- .../modeling_fastspeech2_conformer.py | 10 +++++++--- .../test_modeling_fastspeech2_conformer.py | 3 --- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 286c11b5e9..2b038a9339 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -966,14 +966,18 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1))) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-key, b=key) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_() if module.padding_idx is not None: diff --git a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py index 822bbb1131..4565414a36 100644 --- a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py +++ b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py @@ -28,7 +28,6 @@ from transformers.testing_utils import ( Expectations, require_g2p_en, require_torch, - require_torch_accelerator, slow, torch_device, ) @@ -123,7 +122,6 @@ class FastSpeech2ConformerModelTester: return config, inputs_dict -@require_torch_accelerator @require_torch class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else () @@ -561,7 +559,6 @@ class FastSpeech2ConformerWithHifiGanTester: return config, inputs_dict -@require_torch_accelerator @require_torch class FastSpeech2ConformerWithHifiGanTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (FastSpeech2ConformerWithHifiGan,) if is_torch_available() else ()