From d750eff62757a46160b6f73b95e8035c49c2971b Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:09:26 +0100 Subject: [PATCH] [VITS] Fix init test (#25945) * [VITS] Fix init test * add flaky decorator * style * max attempts Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * style --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- tests/models/vits/test_modeling_vits.py | 52 ++++++++++++++++--------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/tests/models/vits/test_modeling_vits.py b/tests/models/vits/test_modeling_vits.py index 81f8de8a6c..f28c4686ae 100644 --- a/tests/models/vits/test_modeling_vits.py +++ b/tests/models/vits/test_modeling_vits.py @@ -24,6 +24,7 @@ import numpy as np from transformers import PretrainedConfig, VitsConfig from transformers.testing_utils import ( + is_flaky, is_torch_available, require_torch, slow, @@ -80,6 +81,10 @@ class VitsModelTester: duration_predictor_filter_channels=16, prior_encoder_num_flows=2, upsample_initial_channel=16, + upsample_rates=[8, 2], + upsample_kernel_sizes=[16, 4], + resblock_kernel_sizes=[3, 7], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]], ): self.parent = parent self.batch_size = batch_size @@ -96,6 +101,10 @@ class VitsModelTester: self.duration_predictor_filter_channels = duration_predictor_filter_channels self.prior_encoder_num_flows = prior_encoder_num_flows self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2) @@ -126,6 +135,10 @@ class VitsModelTester: duration_predictor_filter_channels=self.duration_predictor_filter_channels, posterior_encoder_num_wavenet_layers=self.num_hidden_layers, upsample_initial_channel=self.upsample_initial_channel, + upsample_rates=self.upsample_rates, + upsample_kernel_sizes=self.upsample_kernel_sizes, + resblock_kernel_sizes=self.resblock_kernel_sizes, + resblock_dilation_sizes=self.resblock_dilation_sizes, ) def create_and_check_model_forward(self, config, inputs_dict): @@ -135,7 +148,7 @@ class VitsModelTester: attention_mask = inputs_dict["attention_mask"] result = model(input_ids, attention_mask=attention_mask) - self.parent.assertEqual(result.waveform.shape, (self.batch_size, 11008)) + self.parent.assertEqual((self.batch_size, 624), result.waveform.shape) @require_torch @@ -168,30 +181,33 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase): def test_determinism(self): pass - # TODO: Fix me (ydshieh) - @unittest.skip("currently failing") + @is_flaky( + max_attempts=3, + description="Weight initialisation for the VITS conv layers sometimes exceeds the kaiming normal range", + ) def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + uniform_init_parms = [ + "emb_rel_k", + "emb_rel_v", + "conv_1", + "conv_2", + "conv_pre", + "conv_post", + "conv_proj", + "conv_dds", + "project", + "wavenet.in_layers", + "wavenet.res_skip_layers", + "upsampler", + "resblocks", + ] + configs_no_init = _config_zero_init(config) for model_class in self.all_model_classes: model = model_class(config=configs_no_init) for name, param in model.named_parameters(): - uniform_init_parms = [ - "emb_rel_k", - "emb_rel_v", - "conv_1", - "conv_2", - "conv_pre", - "conv_post", - "conv_proj", - "conv_dds", - "project", - "wavenet.in_layers", - "wavenet.res_skip_layers", - "upsampler", - "resblocks", - ] if param.requires_grad: if any(x in name for x in uniform_init_parms): self.assertTrue(