From 18e1a9c7195543ec7b7314fcd995bc7aad559e66 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:05:21 +0200 Subject: [PATCH] Fix parametrization-based weight norm (#33275) * refactor weight_norm + propose uniformed solution to reconcile meta load_state_dict with classic loading * make style * fix sew * fix sew and sew_d tests --- src/transformers/modeling_utils.py | 33 +++++++++---- src/transformers/models/dac/modeling_dac.py | 46 ++++++++++--------- .../models/encodec/modeling_encodec.py | 13 +++++- .../modeling_fastspeech2_conformer.py | 18 ++++++-- .../seamless_m4t/modeling_seamless_m4t.py | 18 ++++++-- .../modeling_seamless_m4t_v2.py | 18 ++++++-- src/transformers/models/sew/modeling_sew.py | 8 +++- .../models/sew_d/modeling_sew_d.py | 8 +++- .../models/speecht5/modeling_speecht5.py | 18 ++++++-- .../models/univnet/modeling_univnet.py | 38 +++++++++++---- src/transformers/models/vits/modeling_vits.py | 14 ++++-- tests/models/sew/test_modeling_sew.py | 1 + tests/models/sew_d/test_modeling_sew_d.py | 1 + 13 files changed, 167 insertions(+), 67 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 359509f469..d406976663 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): def _load_state_dict_into_meta_model( model, state_dict, - loaded_state_dict_keys, # left for now but could be removed, see below start_prefix, expected_keys, device_map=None, @@ -847,8 +846,6 @@ def _load_state_dict_into_meta_model( # - deepspeed zero 3 support # - need to copy metadata if any - see _load_state_dict_into_model # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() - # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case - # they won't get loaded. 
error_msgs = [] @@ -868,6 +865,18 @@ def _load_state_dict_into_meta_model( # We add only the first key as an example new_key = key.replace("beta", "bias") renamed_beta[key] = new_key if not renamed_beta else renamed_beta + + # To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weight norm, if necessary. + if hasattr(nn.utils.parametrizations, "weight_norm"): + if "weight_g" in key: + new_key = key.replace("weight_g", "parametrizations.weight.original0") + if "weight_v" in key: + new_key = key.replace("weight_v", "parametrizations.weight.original1") + else: + if "parametrizations.weight.original0" in key: + new_key = key.replace("parametrizations.weight.original0", "weight_g") + if "parametrizations.weight.original1" in key: + new_key = key.replace("parametrizations.weight.original1", "weight_v") if new_key: old_keys.append(key) new_keys.append(new_key) @@ -884,8 +893,7 @@ def _load_state_dict_into_meta_model( is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") for param_name, param in state_dict.items(): - # First part of the test is always true as load_state_dict_keys always contains state_dict keys. 
- if param_name not in loaded_state_dict_keys or param_name not in expected_keys: + if param_name not in expected_keys: continue if param_name.startswith(start_prefix): @@ -4132,6 +4140,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return key.replace("beta", "bias") if "gamma" in key: return key.replace("gamma", "weight") + + # to avoid logging parametrized weight norm renaming + if hasattr(nn.utils.parametrizations, "weight_norm"): + if "weight_g" in key: + return key.replace("weight_g", "parametrizations.weight.original0") + if "weight_v" in key: + return key.replace("weight_v", "parametrizations.weight.original1") + else: + if "parametrizations.weight.original0" in key: + return key.replace("parametrizations.weight.original0", "weight_g") + if "parametrizations.weight.original1" in key: + return key.replace("parametrizations.weight.original1", "weight_v") return key original_loaded_keys = loaded_keys @@ -4376,7 +4396,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, state_dict, - loaded_keys, start_prefix, expected_keys, device_map=device_map, @@ -4453,7 +4472,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, state_dict, - loaded_keys, start_prefix, expected_keys, device_map=device_map, @@ -4609,7 +4627,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix error_msgs = _load_state_dict_into_meta_model( model, state_dict, - loaded_state_dict_keys, start_prefix, expected_keys=expected_keys, hf_quantizer=hf_quantizer, diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index 5211685b02..549f98b59d 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ 
b/src/transformers/models/dac/modeling_dac.py @@ -494,33 +494,37 @@ class DacPreTrainedModel(PreTrainedModel): nn.init.constant_(module.bias, 0) def apply_weight_norm(self): - for layer in self.quantizer.quantizers: - nn.utils.weight_norm(layer.in_proj) - nn.utils.weight_norm(layer.out_proj) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm - nn.utils.weight_norm(self.encoder.conv1) - nn.utils.weight_norm(self.encoder.conv2) + for layer in self.quantizer.quantizers: + weight_norm(layer.in_proj) + weight_norm(layer.out_proj) + + weight_norm(self.encoder.conv1) + weight_norm(self.encoder.conv2) for layer in self.encoder.block: - nn.utils.weight_norm(layer.conv1) - nn.utils.weight_norm(layer.res_unit1.conv1) - nn.utils.weight_norm(layer.res_unit1.conv2) - nn.utils.weight_norm(layer.res_unit2.conv1) - nn.utils.weight_norm(layer.res_unit2.conv2) - nn.utils.weight_norm(layer.res_unit3.conv1) - nn.utils.weight_norm(layer.res_unit3.conv2) + weight_norm(layer.conv1) + weight_norm(layer.res_unit1.conv1) + weight_norm(layer.res_unit1.conv2) + weight_norm(layer.res_unit2.conv1) + weight_norm(layer.res_unit2.conv2) + weight_norm(layer.res_unit3.conv1) + weight_norm(layer.res_unit3.conv2) - nn.utils.weight_norm(self.decoder.conv1) - nn.utils.weight_norm(self.decoder.conv2) + weight_norm(self.decoder.conv1) + weight_norm(self.decoder.conv2) for layer in self.decoder.block: - nn.utils.weight_norm(layer.conv_t1) - nn.utils.weight_norm(layer.res_unit1.conv1) - nn.utils.weight_norm(layer.res_unit1.conv2) - nn.utils.weight_norm(layer.res_unit2.conv1) - nn.utils.weight_norm(layer.res_unit2.conv2) - nn.utils.weight_norm(layer.res_unit3.conv1) - nn.utils.weight_norm(layer.res_unit3.conv2) + weight_norm(layer.conv_t1) + weight_norm(layer.res_unit1.conv1) + weight_norm(layer.res_unit1.conv2) + weight_norm(layer.res_unit2.conv1) + weight_norm(layer.res_unit2.conv2) + 
weight_norm(layer.res_unit3.conv1) + weight_norm(layer.res_unit3.conv2) def remove_weight_norm(self): for layer in self.quantizer.quantizers: diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index f325a6adbe..28ccb9513d 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -103,8 +103,12 @@ class EncodecConv1d(nn.Module): ) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + if self.norm_type == "weight_norm": - self.conv = nn.utils.weight_norm(self.conv) + self.conv = weight_norm(self.conv) elif self.norm_type == "time_group_norm": self.norm = nn.GroupNorm(1, out_channels) @@ -186,8 +190,13 @@ class EncodecConvTranspose1d(nn.Module): ) self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + if config.norm_type == "weight_norm": - self.conv = nn.utils.weight_norm(self.conv) + self.conv = weight_norm(self.conv) elif config.norm_type == "time_group_norm": self.norm = nn.GroupNorm(1, out_channels) diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index e97e276b18..1e1900d38a 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -1416,10 +1416,14 @@ class HifiGanResidualBlock(nn.Module): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if 
hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + for layer in self.convs1: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.convs2: - nn.utils.weight_norm(layer) + weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: @@ -1493,12 +1497,16 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel): module.bias.data.zero_() def apply_weight_norm(self): - nn.utils.weight_norm(self.conv_pre) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) for layer in self.upsampler: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() - nn.utils.weight_norm(self.conv_post) + weight_norm(self.conv_post) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv_pre) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index a79d1d4cf2..ba8230ec50 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -2361,10 +2361,14 @@ class HifiGanResidualBlock(nn.Module): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + for layer in self.convs1: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.convs2: - nn.utils.weight_norm(layer) + weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: @@ -2633,12 +2637,16 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel): module.weight.data[module.padding_idx].zero_() def apply_weight_norm(self): - nn.utils.weight_norm(self.hifi_gan.conv_pre) + weight_norm = nn.utils.weight_norm + if 
hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.hifi_gan.conv_pre) for layer in self.hifi_gan.upsampler: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.hifi_gan.resblocks: layer.apply_weight_norm() - nn.utils.weight_norm(self.hifi_gan.conv_post) + weight_norm(self.hifi_gan.conv_post) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.hifi_gan.conv_pre) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index a53f544bb3..2d1fde8eed 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -2608,10 +2608,14 @@ class HifiGanResidualBlock(nn.Module): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + for layer in self.convs1: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.convs2: - nn.utils.weight_norm(layer) + weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: @@ -2889,12 +2893,16 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel): # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm def apply_weight_norm(self): - nn.utils.weight_norm(self.hifi_gan.conv_pre) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.hifi_gan.conv_pre) for layer in self.hifi_gan.upsampler: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.hifi_gan.resblocks: layer.apply_weight_norm() - nn.utils.weight_norm(self.hifi_gan.conv_post) + weight_norm(self.hifi_gan.conv_post) # Copied 
from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.remove_weight_norm def remove_weight_norm(self): diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 191d7c2cd8..c9a3494b88 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -274,11 +274,15 @@ class SEWPositionalConvEmbedding(nn.Module): stride=config.squeeze_factor, ) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + if is_deepspeed_zero3_enabled(): import deepspeed with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): - self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.conv = weight_norm(self.conv, name="weight", dim=2) if hasattr(self.conv, "parametrizations"): weight_g = self.conv.parametrizations.weight.original0 weight_v = self.conv.parametrizations.weight.original1 @@ -288,7 +292,7 @@ class SEWPositionalConvEmbedding(nn.Module): deepspeed.zero.register_external_parameter(self, weight_v) deepspeed.zero.register_external_parameter(self, weight_g) else: - self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.conv = weight_norm(self.conv, name="weight", dim=2) self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings) self.activation = ACT2FN[config.feat_extract_activation] diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index a617f5e5d6..7f3db54def 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -349,11 +349,15 @@ class SEWDPositionalConvEmbedding(nn.Module): stride=config.squeeze_factor, ) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + if is_deepspeed_zero3_enabled(): 
import deepspeed with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): - self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.conv = weight_norm(self.conv, name="weight", dim=2) if hasattr(self.conv, "parametrizations"): weight_g = self.conv.parametrizations.weight.original0 weight_v = self.conv.parametrizations.weight.original1 @@ -363,7 +367,7 @@ class SEWDPositionalConvEmbedding(nn.Module): deepspeed.zero.register_external_parameter(self, weight_v) deepspeed.zero.register_external_parameter(self, weight_g) else: - self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.conv = weight_norm(self.conv, name="weight", dim=2) self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings) self.activation = ACT2FN[config.feat_extract_activation] diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index a69e9b56eb..790e6a74a4 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -3234,10 +3234,14 @@ class HifiGanResidualBlock(nn.Module): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + for layer in self.convs1: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.convs2: - nn.utils.weight_norm(layer) + weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: @@ -3310,12 +3314,16 @@ class SpeechT5HifiGan(PreTrainedModel): module.bias.data.zero_() def apply_weight_norm(self): - nn.utils.weight_norm(self.conv_pre) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) for layer in self.upsampler: - nn.utils.weight_norm(layer) + 
weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() - nn.utils.weight_norm(self.conv_post) + weight_norm(self.conv_post) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv_pre) diff --git a/src/transformers/models/univnet/modeling_univnet.py b/src/transformers/models/univnet/modeling_univnet.py index 5b0c659c30..a780e54538 100644 --- a/src/transformers/models/univnet/modeling_univnet.py +++ b/src/transformers/models/univnet/modeling_univnet.py @@ -87,8 +87,12 @@ class UnivNetKernelPredictorResidualBlock(nn.Module): return hidden_states + residual def apply_weight_norm(self): - nn.utils.weight_norm(self.conv1) - nn.utils.weight_norm(self.conv2) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv1) + weight_norm(self.conv2) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv1) @@ -197,11 +201,15 @@ class UnivNetKernelPredictor(nn.Module): return kernels, biases def apply_weight_norm(self): - nn.utils.weight_norm(self.input_conv) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.input_conv) for layer in self.resblocks: layer.apply_weight_norm() - nn.utils.weight_norm(self.kernel_conv) - nn.utils.weight_norm(self.bias_conv) + weight_norm(self.kernel_conv) + weight_norm(self.bias_conv) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.input_conv) @@ -328,7 +336,11 @@ class UnivNetLvcResidualBlock(nn.Module): return output_hidden_states def apply_weight_norm(self): - nn.utils.weight_norm(self.conv) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv) @@ -398,7 +410,11 @@ class 
UnivNetLvcBlock(nn.Module): return hidden_states def apply_weight_norm(self): - nn.utils.weight_norm(self.convt_pre) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.convt_pre) self.kernel_predictor.apply_weight_norm() for layer in self.resblocks: layer.apply_weight_norm() @@ -619,10 +635,14 @@ class UnivNetModel(PreTrainedModel): module.bias.data.zero_() def apply_weight_norm(self): - nn.utils.weight_norm(self.conv_pre) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv_pre) for layer in self.resblocks: layer.apply_weight_norm() - nn.utils.weight_norm(self.conv_post) + weight_norm(self.conv_post) def remove_weight_norm(self): nn.utils.remove_weight_norm(self.conv_pre) diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index d8dffd4376..23bc8a72f8 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -461,10 +461,14 @@ class HifiGanResidualBlock(nn.Module): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + for layer in self.convs1: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.convs2: - nn.utils.weight_norm(layer) + weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: @@ -521,8 +525,12 @@ class VitsHifiGan(nn.Module): self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1) def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + for layer in 
self.upsampler: - nn.utils.weight_norm(layer) + weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() diff --git a/tests/models/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py index 6b21c2e9f7..852f87c8f5 100644 --- a/tests/models/sew/test_modeling_sew.py +++ b/tests/models/sew/test_modeling_sew.py @@ -420,6 +420,7 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model = model_class(config=configs_no_init) for name, param in model.named_parameters(): uniform_init_parms = [ + "conv.parametrizations.weight", "conv.weight", "masked_spec_embed", "quantizer.weight_proj.weight", diff --git a/tests/models/sew_d/test_modeling_sew_d.py b/tests/models/sew_d/test_modeling_sew_d.py index b2efdccdf0..34374eb1e0 100644 --- a/tests/models/sew_d/test_modeling_sew_d.py +++ b/tests/models/sew_d/test_modeling_sew_d.py @@ -422,6 +422,7 @@ class SEWDModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model = model_class(config=configs_no_init) for name, param in model.named_parameters(): uniform_init_parms = [ + "conv.parametrizations.weight", "conv.weight", "masked_spec_embed", "quantizer.weight_proj.weight",