[EnCodec] Changes for 32kHz ckpt (#24296)
* [EnCodec] Changes for 32kHz ckpt * Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py * Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
This commit is contained in:
@@ -90,6 +90,9 @@ class EncodecConfig(PretrainedConfig):
|
||||
Number of discret codes that make up VQVAE.
|
||||
codebook_dim (`int`, *optional*):
|
||||
Dimension of the codebook vectors. If not defined, uses `hidden_size`.
|
||||
use_conv_shortcut (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
|
||||
an identity function will be used, giving a generic residual connection.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -131,6 +134,7 @@ class EncodecConfig(PretrainedConfig):
|
||||
trim_right_ratio=1.0,
|
||||
codebook_size=1024,
|
||||
codebook_dim=None,
|
||||
use_conv_shortcut=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.target_bandwidths = target_bandwidths
|
||||
@@ -155,6 +159,7 @@ class EncodecConfig(PretrainedConfig):
|
||||
self.trim_right_ratio = trim_right_ratio
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
|
||||
self.use_conv_shortcut = use_conv_shortcut
|
||||
|
||||
if self.norm_type not in ["weight_norm", "time_group_norm"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -28,6 +28,7 @@ from transformers import (
|
||||
|
||||
# checkpoints downloaded from:
|
||||
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th
|
||||
# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin
|
||||
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th
|
||||
|
||||
|
||||
@@ -206,7 +207,7 @@ def should_ignore(name, ignore_keys):
|
||||
def recursively_load_weights(orig_dict, hf_model, model_name):
|
||||
unused_weights = []
|
||||
|
||||
if model_name == "encodec_24khz":
|
||||
if model_name == "encodec_24khz" or "encodec_32khz":
|
||||
MAPPING = MAPPING_24K
|
||||
elif model_name == "encodec_48khz":
|
||||
MAPPING = MAPPING_48K
|
||||
@@ -292,6 +293,15 @@ def convert_checkpoint(
|
||||
|
||||
if model_name == "encodec_24khz":
|
||||
pass # config is already correct
|
||||
elif model_name == "encodec_32khz":
|
||||
config.upsampling_ratios = [8, 5, 4, 4]
|
||||
config.target_bandwidths = [2.2]
|
||||
config.num_filters = 64
|
||||
config.sampling_rate = 32_000
|
||||
config.codebook_size = 2048
|
||||
config.use_causal_conv = False
|
||||
config.normalize = False
|
||||
config.use_conv_shortcut = False
|
||||
elif model_name == "encodec_48khz":
|
||||
config.upsampling_ratios = [8, 5, 4, 2]
|
||||
config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
|
||||
@@ -316,6 +326,9 @@ def convert_checkpoint(
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
original_checkpoint = torch.load(checkpoint_path)
|
||||
if "best_state" in original_checkpoint:
|
||||
# we might have a training state saved, in which case discard the yaml results and just retain the weights
|
||||
original_checkpoint = original_checkpoint["best_state"]
|
||||
recursively_load_weights(original_checkpoint, model, model_name)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
@@ -331,7 +344,7 @@ if __name__ == "__main__":
|
||||
"--model",
|
||||
default="encodec_24khz",
|
||||
type=str,
|
||||
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_48khz'.",
|
||||
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
|
||||
)
|
||||
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
|
||||
@@ -259,7 +259,10 @@ class EncodecResnetBlock(nn.Module):
|
||||
block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
|
||||
self.block = nn.ModuleList(block)
|
||||
|
||||
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
|
||||
if config.use_conv_shortcut:
|
||||
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
|
||||
else:
|
||||
self.shortcut = nn.Identity()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
|
||||
@@ -385,6 +385,11 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def test_identity_shortcut(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.use_conv_shortcut = False
|
||||
self.model_tester.create_and_check_model_forward(config, inputs_dict)
|
||||
|
||||
|
||||
def normalize(arr):
|
||||
norm = np.linalg.norm(arr)
|
||||
|
||||
Reference in New Issue
Block a user