Fix parametrization-based weight norm (#33275)

* refactor weight_norm + propose uniformed solution to reconcile meta load_state_dict with classic loading

* make style

* fix sew

* fix sew and sew_d tests
This commit is contained in:
Yoach Lacombe
2024-09-17 08:05:21 +02:00
committed by GitHub
parent 9f196ef2e0
commit 18e1a9c719
13 changed files with 167 additions and 67 deletions

View File

@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
def _load_state_dict_into_meta_model(
model,
state_dict,
loaded_state_dict_keys, # left for now but could be removed, see below
start_prefix,
expected_keys,
device_map=None,
@@ -847,8 +846,6 @@ def _load_state_dict_into_meta_model(
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.
error_msgs = []
@@ -868,6 +865,18 @@ def _load_state_dict_into_meta_model(
# We add only the first key as an example
new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
# To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary.
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
new_key = key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
new_key = key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
new_key = key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
new_key = key.replace("parametrizations.weight.original1", "weight_v")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
@@ -884,8 +893,7 @@ def _load_state_dict_into_meta_model(
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
if param_name not in expected_keys:
continue
if param_name.startswith(start_prefix):
@@ -4132,6 +4140,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
# to avoid logging parametrized weight norm renaming
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key
original_loaded_keys = loaded_keys
@@ -4376,7 +4396,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4453,7 +4472,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4609,7 +4627,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
error_msgs = _load_state_dict_into_meta_model(
model,
state_dict,
loaded_state_dict_keys,
start_prefix,
expected_keys=expected_keys,
hf_quantizer=hf_quantizer,

View File

@@ -494,33 +494,37 @@ class DacPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.bias, 0)
def apply_weight_norm(self):
for layer in self.quantizer.quantizers:
nn.utils.weight_norm(layer.in_proj)
nn.utils.weight_norm(layer.out_proj)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
nn.utils.weight_norm(self.encoder.conv1)
nn.utils.weight_norm(self.encoder.conv2)
for layer in self.quantizer.quantizers:
weight_norm(layer.in_proj)
weight_norm(layer.out_proj)
weight_norm(self.encoder.conv1)
weight_norm(self.encoder.conv2)
for layer in self.encoder.block:
nn.utils.weight_norm(layer.conv1)
nn.utils.weight_norm(layer.res_unit1.conv1)
nn.utils.weight_norm(layer.res_unit1.conv2)
nn.utils.weight_norm(layer.res_unit2.conv1)
nn.utils.weight_norm(layer.res_unit2.conv2)
nn.utils.weight_norm(layer.res_unit3.conv1)
nn.utils.weight_norm(layer.res_unit3.conv2)
weight_norm(layer.conv1)
weight_norm(layer.res_unit1.conv1)
weight_norm(layer.res_unit1.conv2)
weight_norm(layer.res_unit2.conv1)
weight_norm(layer.res_unit2.conv2)
weight_norm(layer.res_unit3.conv1)
weight_norm(layer.res_unit3.conv2)
nn.utils.weight_norm(self.decoder.conv1)
nn.utils.weight_norm(self.decoder.conv2)
weight_norm(self.decoder.conv1)
weight_norm(self.decoder.conv2)
for layer in self.decoder.block:
nn.utils.weight_norm(layer.conv_t1)
nn.utils.weight_norm(layer.res_unit1.conv1)
nn.utils.weight_norm(layer.res_unit1.conv2)
nn.utils.weight_norm(layer.res_unit2.conv1)
nn.utils.weight_norm(layer.res_unit2.conv2)
nn.utils.weight_norm(layer.res_unit3.conv1)
nn.utils.weight_norm(layer.res_unit3.conv2)
weight_norm(layer.conv_t1)
weight_norm(layer.res_unit1.conv1)
weight_norm(layer.res_unit1.conv2)
weight_norm(layer.res_unit2.conv1)
weight_norm(layer.res_unit2.conv2)
weight_norm(layer.res_unit3.conv1)
weight_norm(layer.res_unit3.conv2)
def remove_weight_norm(self):
for layer in self.quantizer.quantizers:

View File

@@ -103,8 +103,12 @@ class EncodecConv1d(nn.Module):
)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if self.norm_type == "weight_norm":
self.conv = nn.utils.weight_norm(self.conv)
self.conv = weight_norm(self.conv)
elif self.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels)
@@ -186,8 +190,13 @@ class EncodecConvTranspose1d(nn.Module):
)
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if config.norm_type == "weight_norm":
self.conv = nn.utils.weight_norm(self.conv)
self.conv = weight_norm(self.conv)
elif config.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels)

View File

@@ -1416,10 +1416,14 @@ class HifiGanResidualBlock(nn.Module):
return (kernel_size * dilation - dilation) // 2
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)
def remove_weight_norm(self):
for layer in self.convs1:
@@ -1493,12 +1497,16 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
module.bias.data.zero_()
def apply_weight_norm(self):
nn.utils.weight_norm(self.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.conv_pre)
for layer in self.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.conv_post)
weight_norm(self.conv_post)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)

View File

@@ -2361,10 +2361,14 @@ class HifiGanResidualBlock(nn.Module):
return (kernel_size * dilation - dilation) // 2
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)
def remove_weight_norm(self):
for layer in self.convs1:
@@ -2633,12 +2637,16 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
def apply_weight_norm(self):
nn.utils.weight_norm(self.hifi_gan.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.hifi_gan.conv_pre)
for layer in self.hifi_gan.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.hifi_gan.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.hifi_gan.conv_post)
weight_norm(self.hifi_gan.conv_post)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.hifi_gan.conv_pre)

View File

@@ -2608,10 +2608,14 @@ class HifiGanResidualBlock(nn.Module):
return (kernel_size * dilation - dilation) // 2
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)
def remove_weight_norm(self):
for layer in self.convs1:
@@ -2889,12 +2893,16 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm
def apply_weight_norm(self):
nn.utils.weight_norm(self.hifi_gan.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.hifi_gan.conv_pre)
for layer in self.hifi_gan.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.hifi_gan.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.hifi_gan.conv_post)
weight_norm(self.hifi_gan.conv_post)
# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.remove_weight_norm
def remove_weight_norm(self):

View File

@@ -274,11 +274,15 @@ class SEWPositionalConvEmbedding(nn.Module):
stride=config.squeeze_factor,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
@@ -288,7 +292,7 @@ class SEWPositionalConvEmbedding(nn.Module):
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@@ -349,11 +349,15 @@ class SEWDPositionalConvEmbedding(nn.Module):
stride=config.squeeze_factor,
)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
@@ -363,7 +367,7 @@ class SEWDPositionalConvEmbedding(nn.Module):
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]

View File

@@ -3234,10 +3234,14 @@ class HifiGanResidualBlock(nn.Module):
return (kernel_size * dilation - dilation) // 2
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)
def remove_weight_norm(self):
for layer in self.convs1:
@@ -3310,12 +3314,16 @@ class SpeechT5HifiGan(PreTrainedModel):
module.bias.data.zero_()
def apply_weight_norm(self):
nn.utils.weight_norm(self.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.conv_pre)
for layer in self.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.conv_post)
weight_norm(self.conv_post)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)

View File

@@ -87,8 +87,12 @@ class UnivNetKernelPredictorResidualBlock(nn.Module):
return hidden_states + residual
def apply_weight_norm(self):
nn.utils.weight_norm(self.conv1)
nn.utils.weight_norm(self.conv2)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.conv1)
weight_norm(self.conv2)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv1)
@@ -197,11 +201,15 @@ class UnivNetKernelPredictor(nn.Module):
return kernels, biases
def apply_weight_norm(self):
nn.utils.weight_norm(self.input_conv)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.input_conv)
for layer in self.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.kernel_conv)
nn.utils.weight_norm(self.bias_conv)
weight_norm(self.kernel_conv)
weight_norm(self.bias_conv)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv)
@@ -328,7 +336,11 @@ class UnivNetLvcResidualBlock(nn.Module):
return output_hidden_states
def apply_weight_norm(self):
nn.utils.weight_norm(self.conv)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.conv)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)
@@ -398,7 +410,11 @@ class UnivNetLvcBlock(nn.Module):
return hidden_states
def apply_weight_norm(self):
nn.utils.weight_norm(self.convt_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.convt_pre)
self.kernel_predictor.apply_weight_norm()
for layer in self.resblocks:
layer.apply_weight_norm()
@@ -619,10 +635,14 @@ class UnivNetModel(PreTrainedModel):
module.bias.data.zero_()
def apply_weight_norm(self):
nn.utils.weight_norm(self.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
weight_norm(self.conv_pre)
for layer in self.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.conv_post)
weight_norm(self.conv_post)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)

View File

@@ -461,10 +461,14 @@ class HifiGanResidualBlock(nn.Module):
return (kernel_size * dilation - dilation) // 2
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)
def remove_weight_norm(self):
for layer in self.convs1:
@@ -521,8 +525,12 @@ class VitsHifiGan(nn.Module):
self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm
for layer in self.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.resblocks:
layer.apply_weight_norm()

View File

@@ -420,6 +420,7 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.parametrizations.weight",
"conv.weight",
"masked_spec_embed",
"quantizer.weight_proj.weight",

View File

@@ -422,6 +422,7 @@ class SEWDModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = [
"conv.parametrizations.weight",
"conv.weight",
"masked_spec_embed",
"quantizer.weight_proj.weight",