Fix parametrization-based weight norm (#33275)
* refactor weight_norm + propose unified solution to reconcile meta load_state_dict with classic loading * make style * fix sew * fix sew and sew_d tests
This commit is contained in:
@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
|||||||
def _load_state_dict_into_meta_model(
|
def _load_state_dict_into_meta_model(
|
||||||
model,
|
model,
|
||||||
state_dict,
|
state_dict,
|
||||||
loaded_state_dict_keys, # left for now but could be removed, see below
|
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
device_map=None,
|
device_map=None,
|
||||||
@@ -847,8 +846,6 @@ def _load_state_dict_into_meta_model(
|
|||||||
# - deepspeed zero 3 support
|
# - deepspeed zero 3 support
|
||||||
# - need to copy metadata if any - see _load_state_dict_into_model
|
# - need to copy metadata if any - see _load_state_dict_into_model
|
||||||
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
|
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
|
||||||
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
|
|
||||||
# they won't get loaded.
|
|
||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
|
|
||||||
@@ -868,6 +865,18 @@ def _load_state_dict_into_meta_model(
|
|||||||
# We add only the first key as an example
|
# We add only the first key as an example
|
||||||
new_key = key.replace("beta", "bias")
|
new_key = key.replace("beta", "bias")
|
||||||
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
|
||||||
|
|
||||||
|
# To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weight norm, if necessary.
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
if "weight_g" in key:
|
||||||
|
new_key = key.replace("weight_g", "parametrizations.weight.original0")
|
||||||
|
if "weight_v" in key:
|
||||||
|
new_key = key.replace("weight_v", "parametrizations.weight.original1")
|
||||||
|
else:
|
||||||
|
if "parametrizations.weight.original0" in key:
|
||||||
|
new_key = key.replace("parametrizations.weight.original0", "weight_g")
|
||||||
|
if "parametrizations.weight.original1" in key:
|
||||||
|
new_key = key.replace("parametrizations.weight.original1", "weight_v")
|
||||||
if new_key:
|
if new_key:
|
||||||
old_keys.append(key)
|
old_keys.append(key)
|
||||||
new_keys.append(new_key)
|
new_keys.append(new_key)
|
||||||
@@ -884,8 +893,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||||
|
|
||||||
for param_name, param in state_dict.items():
|
for param_name, param in state_dict.items():
|
||||||
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
|
if param_name not in expected_keys:
|
||||||
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if param_name.startswith(start_prefix):
|
if param_name.startswith(start_prefix):
|
||||||
@@ -4132,6 +4140,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return key.replace("beta", "bias")
|
return key.replace("beta", "bias")
|
||||||
if "gamma" in key:
|
if "gamma" in key:
|
||||||
return key.replace("gamma", "weight")
|
return key.replace("gamma", "weight")
|
||||||
|
|
||||||
|
# to avoid logging parametrized weight norm renaming
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
if "weight_g" in key:
|
||||||
|
return key.replace("weight_g", "parametrizations.weight.original0")
|
||||||
|
if "weight_v" in key:
|
||||||
|
return key.replace("weight_v", "parametrizations.weight.original1")
|
||||||
|
else:
|
||||||
|
if "parametrizations.weight.original0" in key:
|
||||||
|
return key.replace("parametrizations.weight.original0", "weight_g")
|
||||||
|
if "parametrizations.weight.original1" in key:
|
||||||
|
return key.replace("parametrizations.weight.original1", "weight_v")
|
||||||
return key
|
return key
|
||||||
|
|
||||||
original_loaded_keys = loaded_keys
|
original_loaded_keys = loaded_keys
|
||||||
@@ -4376,7 +4396,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
state_dict,
|
state_dict,
|
||||||
loaded_keys,
|
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
@@ -4453,7 +4472,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||||
model_to_load,
|
model_to_load,
|
||||||
state_dict,
|
state_dict,
|
||||||
loaded_keys,
|
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys,
|
expected_keys,
|
||||||
device_map=device_map,
|
device_map=device_map,
|
||||||
@@ -4609,7 +4627,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
error_msgs = _load_state_dict_into_meta_model(
|
error_msgs = _load_state_dict_into_meta_model(
|
||||||
model,
|
model,
|
||||||
state_dict,
|
state_dict,
|
||||||
loaded_state_dict_keys,
|
|
||||||
start_prefix,
|
start_prefix,
|
||||||
expected_keys=expected_keys,
|
expected_keys=expected_keys,
|
||||||
hf_quantizer=hf_quantizer,
|
hf_quantizer=hf_quantizer,
|
||||||
|
|||||||
@@ -494,33 +494,37 @@ class DacPreTrainedModel(PreTrainedModel):
|
|||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
for layer in self.quantizer.quantizers:
|
weight_norm = nn.utils.weight_norm
|
||||||
nn.utils.weight_norm(layer.in_proj)
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
nn.utils.weight_norm(layer.out_proj)
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
nn.utils.weight_norm(self.encoder.conv1)
|
for layer in self.quantizer.quantizers:
|
||||||
nn.utils.weight_norm(self.encoder.conv2)
|
weight_norm(layer.in_proj)
|
||||||
|
weight_norm(layer.out_proj)
|
||||||
|
|
||||||
|
weight_norm(self.encoder.conv1)
|
||||||
|
weight_norm(self.encoder.conv2)
|
||||||
|
|
||||||
for layer in self.encoder.block:
|
for layer in self.encoder.block:
|
||||||
nn.utils.weight_norm(layer.conv1)
|
weight_norm(layer.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit1.conv1)
|
weight_norm(layer.res_unit1.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit1.conv2)
|
weight_norm(layer.res_unit1.conv2)
|
||||||
nn.utils.weight_norm(layer.res_unit2.conv1)
|
weight_norm(layer.res_unit2.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit2.conv2)
|
weight_norm(layer.res_unit2.conv2)
|
||||||
nn.utils.weight_norm(layer.res_unit3.conv1)
|
weight_norm(layer.res_unit3.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit3.conv2)
|
weight_norm(layer.res_unit3.conv2)
|
||||||
|
|
||||||
nn.utils.weight_norm(self.decoder.conv1)
|
weight_norm(self.decoder.conv1)
|
||||||
nn.utils.weight_norm(self.decoder.conv2)
|
weight_norm(self.decoder.conv2)
|
||||||
|
|
||||||
for layer in self.decoder.block:
|
for layer in self.decoder.block:
|
||||||
nn.utils.weight_norm(layer.conv_t1)
|
weight_norm(layer.conv_t1)
|
||||||
nn.utils.weight_norm(layer.res_unit1.conv1)
|
weight_norm(layer.res_unit1.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit1.conv2)
|
weight_norm(layer.res_unit1.conv2)
|
||||||
nn.utils.weight_norm(layer.res_unit2.conv1)
|
weight_norm(layer.res_unit2.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit2.conv2)
|
weight_norm(layer.res_unit2.conv2)
|
||||||
nn.utils.weight_norm(layer.res_unit3.conv1)
|
weight_norm(layer.res_unit3.conv1)
|
||||||
nn.utils.weight_norm(layer.res_unit3.conv2)
|
weight_norm(layer.res_unit3.conv2)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for layer in self.quantizer.quantizers:
|
for layer in self.quantizer.quantizers:
|
||||||
|
|||||||
@@ -103,8 +103,12 @@ class EncodecConv1d(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
|
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
if self.norm_type == "weight_norm":
|
if self.norm_type == "weight_norm":
|
||||||
self.conv = nn.utils.weight_norm(self.conv)
|
self.conv = weight_norm(self.conv)
|
||||||
elif self.norm_type == "time_group_norm":
|
elif self.norm_type == "time_group_norm":
|
||||||
self.norm = nn.GroupNorm(1, out_channels)
|
self.norm = nn.GroupNorm(1, out_channels)
|
||||||
|
|
||||||
@@ -186,8 +190,13 @@ class EncodecConvTranspose1d(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
||||||
|
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
if config.norm_type == "weight_norm":
|
if config.norm_type == "weight_norm":
|
||||||
self.conv = nn.utils.weight_norm(self.conv)
|
self.conv = weight_norm(self.conv)
|
||||||
elif config.norm_type == "time_group_norm":
|
elif config.norm_type == "time_group_norm":
|
||||||
self.norm = nn.GroupNorm(1, out_channels)
|
self.norm = nn.GroupNorm(1, out_channels)
|
||||||
|
|
||||||
|
|||||||
@@ -1416,10 +1416,14 @@ class HifiGanResidualBlock(nn.Module):
|
|||||||
return (kernel_size * dilation - dilation) // 2
|
return (kernel_size * dilation - dilation) // 2
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.convs2:
|
for layer in self.convs2:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
@@ -1493,12 +1497,16 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.conv_pre)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.conv_pre)
|
||||||
for layer in self.upsampler:
|
for layer in self.upsampler:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.resblocks:
|
for layer in self.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
nn.utils.weight_norm(self.conv_post)
|
weight_norm(self.conv_post)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.conv_pre)
|
nn.utils.remove_weight_norm(self.conv_pre)
|
||||||
|
|||||||
@@ -2361,10 +2361,14 @@ class HifiGanResidualBlock(nn.Module):
|
|||||||
return (kernel_size * dilation - dilation) // 2
|
return (kernel_size * dilation - dilation) // 2
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.convs2:
|
for layer in self.convs2:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
@@ -2633,12 +2637,16 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
|
|||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.hifi_gan.conv_pre)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.hifi_gan.conv_pre)
|
||||||
for layer in self.hifi_gan.upsampler:
|
for layer in self.hifi_gan.upsampler:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.hifi_gan.resblocks:
|
for layer in self.hifi_gan.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
nn.utils.weight_norm(self.hifi_gan.conv_post)
|
weight_norm(self.hifi_gan.conv_post)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.hifi_gan.conv_pre)
|
nn.utils.remove_weight_norm(self.hifi_gan.conv_pre)
|
||||||
|
|||||||
@@ -2608,10 +2608,14 @@ class HifiGanResidualBlock(nn.Module):
|
|||||||
return (kernel_size * dilation - dilation) // 2
|
return (kernel_size * dilation - dilation) // 2
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.convs2:
|
for layer in self.convs2:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
@@ -2889,12 +2893,16 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
|
|||||||
|
|
||||||
# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm
|
# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.hifi_gan.conv_pre)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.hifi_gan.conv_pre)
|
||||||
for layer in self.hifi_gan.upsampler:
|
for layer in self.hifi_gan.upsampler:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.hifi_gan.resblocks:
|
for layer in self.hifi_gan.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
nn.utils.weight_norm(self.hifi_gan.conv_post)
|
weight_norm(self.hifi_gan.conv_post)
|
||||||
|
|
||||||
# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.remove_weight_norm
|
# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.remove_weight_norm
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
|
|||||||
@@ -274,11 +274,15 @@ class SEWPositionalConvEmbedding(nn.Module):
|
|||||||
stride=config.squeeze_factor,
|
stride=config.squeeze_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
||||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||||
if hasattr(self.conv, "parametrizations"):
|
if hasattr(self.conv, "parametrizations"):
|
||||||
weight_g = self.conv.parametrizations.weight.original0
|
weight_g = self.conv.parametrizations.weight.original0
|
||||||
weight_v = self.conv.parametrizations.weight.original1
|
weight_v = self.conv.parametrizations.weight.original1
|
||||||
@@ -288,7 +292,7 @@ class SEWPositionalConvEmbedding(nn.Module):
|
|||||||
deepspeed.zero.register_external_parameter(self, weight_v)
|
deepspeed.zero.register_external_parameter(self, weight_v)
|
||||||
deepspeed.zero.register_external_parameter(self, weight_g)
|
deepspeed.zero.register_external_parameter(self, weight_g)
|
||||||
else:
|
else:
|
||||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||||
|
|
||||||
self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
|
self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
|
||||||
self.activation = ACT2FN[config.feat_extract_activation]
|
self.activation = ACT2FN[config.feat_extract_activation]
|
||||||
|
|||||||
@@ -349,11 +349,15 @@ class SEWDPositionalConvEmbedding(nn.Module):
|
|||||||
stride=config.squeeze_factor,
|
stride=config.squeeze_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
||||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||||
if hasattr(self.conv, "parametrizations"):
|
if hasattr(self.conv, "parametrizations"):
|
||||||
weight_g = self.conv.parametrizations.weight.original0
|
weight_g = self.conv.parametrizations.weight.original0
|
||||||
weight_v = self.conv.parametrizations.weight.original1
|
weight_v = self.conv.parametrizations.weight.original1
|
||||||
@@ -363,7 +367,7 @@ class SEWDPositionalConvEmbedding(nn.Module):
|
|||||||
deepspeed.zero.register_external_parameter(self, weight_v)
|
deepspeed.zero.register_external_parameter(self, weight_v)
|
||||||
deepspeed.zero.register_external_parameter(self, weight_g)
|
deepspeed.zero.register_external_parameter(self, weight_g)
|
||||||
else:
|
else:
|
||||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
self.conv = weight_norm(self.conv, name="weight", dim=2)
|
||||||
|
|
||||||
self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings)
|
self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings)
|
||||||
self.activation = ACT2FN[config.feat_extract_activation]
|
self.activation = ACT2FN[config.feat_extract_activation]
|
||||||
|
|||||||
@@ -3234,10 +3234,14 @@ class HifiGanResidualBlock(nn.Module):
|
|||||||
return (kernel_size * dilation - dilation) // 2
|
return (kernel_size * dilation - dilation) // 2
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.convs2:
|
for layer in self.convs2:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
@@ -3310,12 +3314,16 @@ class SpeechT5HifiGan(PreTrainedModel):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.conv_pre)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.conv_pre)
|
||||||
for layer in self.upsampler:
|
for layer in self.upsampler:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.resblocks:
|
for layer in self.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
nn.utils.weight_norm(self.conv_post)
|
weight_norm(self.conv_post)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.conv_pre)
|
nn.utils.remove_weight_norm(self.conv_pre)
|
||||||
|
|||||||
@@ -87,8 +87,12 @@ class UnivNetKernelPredictorResidualBlock(nn.Module):
|
|||||||
return hidden_states + residual
|
return hidden_states + residual
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.conv1)
|
weight_norm = nn.utils.weight_norm
|
||||||
nn.utils.weight_norm(self.conv2)
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.conv1)
|
||||||
|
weight_norm(self.conv2)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.conv1)
|
nn.utils.remove_weight_norm(self.conv1)
|
||||||
@@ -197,11 +201,15 @@ class UnivNetKernelPredictor(nn.Module):
|
|||||||
return kernels, biases
|
return kernels, biases
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.input_conv)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.input_conv)
|
||||||
for layer in self.resblocks:
|
for layer in self.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
nn.utils.weight_norm(self.kernel_conv)
|
weight_norm(self.kernel_conv)
|
||||||
nn.utils.weight_norm(self.bias_conv)
|
weight_norm(self.bias_conv)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.input_conv)
|
nn.utils.remove_weight_norm(self.input_conv)
|
||||||
@@ -328,7 +336,11 @@ class UnivNetLvcResidualBlock(nn.Module):
|
|||||||
return output_hidden_states
|
return output_hidden_states
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.conv)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.conv)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.conv)
|
nn.utils.remove_weight_norm(self.conv)
|
||||||
@@ -398,7 +410,11 @@ class UnivNetLvcBlock(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.convt_pre)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.convt_pre)
|
||||||
self.kernel_predictor.apply_weight_norm()
|
self.kernel_predictor.apply_weight_norm()
|
||||||
for layer in self.resblocks:
|
for layer in self.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
@@ -619,10 +635,14 @@ class UnivNetModel(PreTrainedModel):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
nn.utils.weight_norm(self.conv_pre)
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
|
weight_norm(self.conv_pre)
|
||||||
for layer in self.resblocks:
|
for layer in self.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
nn.utils.weight_norm(self.conv_post)
|
weight_norm(self.conv_post)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
nn.utils.remove_weight_norm(self.conv_pre)
|
nn.utils.remove_weight_norm(self.conv_pre)
|
||||||
|
|||||||
@@ -461,10 +461,14 @@ class HifiGanResidualBlock(nn.Module):
|
|||||||
return (kernel_size * dilation - dilation) // 2
|
return (kernel_size * dilation - dilation) // 2
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.convs2:
|
for layer in self.convs2:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for layer in self.convs1:
|
for layer in self.convs1:
|
||||||
@@ -521,8 +525,12 @@ class VitsHifiGan(nn.Module):
|
|||||||
self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
|
self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
weight_norm = nn.utils.weight_norm
|
||||||
|
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
||||||
|
weight_norm = nn.utils.parametrizations.weight_norm
|
||||||
|
|
||||||
for layer in self.upsampler:
|
for layer in self.upsampler:
|
||||||
nn.utils.weight_norm(layer)
|
weight_norm(layer)
|
||||||
for layer in self.resblocks:
|
for layer in self.resblocks:
|
||||||
layer.apply_weight_norm()
|
layer.apply_weight_norm()
|
||||||
|
|
||||||
|
|||||||
@@ -420,6 +420,7 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
uniform_init_parms = [
|
uniform_init_parms = [
|
||||||
|
"conv.parametrizations.weight",
|
||||||
"conv.weight",
|
"conv.weight",
|
||||||
"masked_spec_embed",
|
"masked_spec_embed",
|
||||||
"quantizer.weight_proj.weight",
|
"quantizer.weight_proj.weight",
|
||||||
|
|||||||
@@ -422,6 +422,7 @@ class SEWDModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init)
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
uniform_init_parms = [
|
uniform_init_parms = [
|
||||||
|
"conv.parametrizations.weight",
|
||||||
"conv.weight",
|
"conv.weight",
|
||||||
"masked_spec_embed",
|
"masked_spec_embed",
|
||||||
"quantizer.weight_proj.weight",
|
"quantizer.weight_proj.weight",
|
||||||
|
|||||||
Reference in New Issue
Block a user