[EnCodec] Changes for 32kHz ckpt (#24296)
* [EnCodec] Changes for 32kHz ckpt * Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py * Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
This commit is contained in:
@@ -90,6 +90,9 @@ class EncodecConfig(PretrainedConfig):
|
|||||||
Number of discret codes that make up VQVAE.
|
Number of discret codes that make up VQVAE.
|
||||||
codebook_dim (`int`, *optional*):
|
codebook_dim (`int`, *optional*):
|
||||||
Dimension of the codebook vectors. If not defined, uses `hidden_size`.
|
Dimension of the codebook vectors. If not defined, uses `hidden_size`.
|
||||||
|
use_conv_shortcut (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
|
||||||
|
an identity function will be used, giving a generic residual connection.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -131,6 +134,7 @@ class EncodecConfig(PretrainedConfig):
|
|||||||
trim_right_ratio=1.0,
|
trim_right_ratio=1.0,
|
||||||
codebook_size=1024,
|
codebook_size=1024,
|
||||||
codebook_dim=None,
|
codebook_dim=None,
|
||||||
|
use_conv_shortcut=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.target_bandwidths = target_bandwidths
|
self.target_bandwidths = target_bandwidths
|
||||||
@@ -155,6 +159,7 @@ class EncodecConfig(PretrainedConfig):
|
|||||||
self.trim_right_ratio = trim_right_ratio
|
self.trim_right_ratio = trim_right_ratio
|
||||||
self.codebook_size = codebook_size
|
self.codebook_size = codebook_size
|
||||||
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
|
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
|
||||||
|
self.use_conv_shortcut = use_conv_shortcut
|
||||||
|
|
||||||
if self.norm_type not in ["weight_norm", "time_group_norm"]:
|
if self.norm_type not in ["weight_norm", "time_group_norm"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
|
|
||||||
# checkpoints downloaded from:
|
# checkpoints downloaded from:
|
||||||
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th
|
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th
|
||||||
|
# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin
|
||||||
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th
|
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th
|
||||||
|
|
||||||
|
|
||||||
@@ -206,7 +207,7 @@ def should_ignore(name, ignore_keys):
|
|||||||
def recursively_load_weights(orig_dict, hf_model, model_name):
|
def recursively_load_weights(orig_dict, hf_model, model_name):
|
||||||
unused_weights = []
|
unused_weights = []
|
||||||
|
|
||||||
if model_name == "encodec_24khz":
|
if model_name == "encodec_24khz" or "encodec_32khz":
|
||||||
MAPPING = MAPPING_24K
|
MAPPING = MAPPING_24K
|
||||||
elif model_name == "encodec_48khz":
|
elif model_name == "encodec_48khz":
|
||||||
MAPPING = MAPPING_48K
|
MAPPING = MAPPING_48K
|
||||||
@@ -292,6 +293,15 @@ def convert_checkpoint(
|
|||||||
|
|
||||||
if model_name == "encodec_24khz":
|
if model_name == "encodec_24khz":
|
||||||
pass # config is already correct
|
pass # config is already correct
|
||||||
|
elif model_name == "encodec_32khz":
|
||||||
|
config.upsampling_ratios = [8, 5, 4, 4]
|
||||||
|
config.target_bandwidths = [2.2]
|
||||||
|
config.num_filters = 64
|
||||||
|
config.sampling_rate = 32_000
|
||||||
|
config.codebook_size = 2048
|
||||||
|
config.use_causal_conv = False
|
||||||
|
config.normalize = False
|
||||||
|
config.use_conv_shortcut = False
|
||||||
elif model_name == "encodec_48khz":
|
elif model_name == "encodec_48khz":
|
||||||
config.upsampling_ratios = [8, 5, 4, 2]
|
config.upsampling_ratios = [8, 5, 4, 2]
|
||||||
config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
|
config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
|
||||||
@@ -316,6 +326,9 @@ def convert_checkpoint(
|
|||||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
original_checkpoint = torch.load(checkpoint_path)
|
original_checkpoint = torch.load(checkpoint_path)
|
||||||
|
if "best_state" in original_checkpoint:
|
||||||
|
# we might have a training state saved, in which case discard the yaml results and just retain the weights
|
||||||
|
original_checkpoint = original_checkpoint["best_state"]
|
||||||
recursively_load_weights(original_checkpoint, model, model_name)
|
recursively_load_weights(original_checkpoint, model, model_name)
|
||||||
model.save_pretrained(pytorch_dump_folder_path)
|
model.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
@@ -331,7 +344,7 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
default="encodec_24khz",
|
default="encodec_24khz",
|
||||||
type=str,
|
type=str,
|
||||||
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_48khz'.",
|
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
||||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||||
|
|||||||
@@ -259,7 +259,10 @@ class EncodecResnetBlock(nn.Module):
|
|||||||
block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
|
block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
|
||||||
self.block = nn.ModuleList(block)
|
self.block = nn.ModuleList(block)
|
||||||
|
|
||||||
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
|
if config.use_conv_shortcut:
|
||||||
|
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
|
||||||
|
else:
|
||||||
|
self.shortcut = nn.Identity()
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|||||||
@@ -385,6 +385,11 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_identity_shortcut(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
config.use_conv_shortcut = False
|
||||||
|
self.model_tester.create_and_check_model_forward(config, inputs_dict)
|
||||||
|
|
||||||
|
|
||||||
def normalize(arr):
|
def normalize(arr):
|
||||||
norm = np.linalg.norm(arr)
|
norm = np.linalg.norm(arr)
|
||||||
|
|||||||
Reference in New Issue
Block a user