diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 286c11b5e9..2b038a9339 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -966,14 +966,18 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" - if isinstance(module, (nn.LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1))) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-key, b=key) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) elif isinstance(module, nn.Embedding): module.weight.data.normal_() if module.padding_idx is not None: diff --git a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py index 822bbb1131..4565414a36 100644 --- a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py +++ b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py @@ -28,7 +28,6 @@ from transformers.testing_utils import ( Expectations, require_g2p_en, require_torch, - require_torch_accelerator, slow, torch_device, ) @@ -123,7 +122,6 @@ class FastSpeech2ConformerModelTester: return config, inputs_dict -@require_torch_accelerator @require_torch class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else () @@ -561,7 +559,6 @@ class FastSpeech2ConformerWithHifiGanTester: return config, inputs_dict -@require_torch_accelerator @require_torch class FastSpeech2ConformerWithHifiGanTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (FastSpeech2ConformerWithHifiGan,) if is_torch_available() else ()