From 4124a09f8b3349f338917ad3282ca952bd15ec3a Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 15 Jun 2023 14:36:19 +0100 Subject: [PATCH] [EnCodec] Changes for 32kHz ckpt (#24296) * [EnCodec] Changes for 32kHz ckpt * Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py * Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py --- .../models/encodec/configuration_encodec.py | 5 +++++ .../convert_encodec_checkpoint_to_pytorch.py | 17 +++++++++++++++-- .../models/encodec/modeling_encodec.py | 5 ++++- tests/models/encodec/test_modeling_encodec.py | 5 +++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/encodec/configuration_encodec.py b/src/transformers/models/encodec/configuration_encodec.py index 9ea2cfee94..e75711d926 100644 --- a/src/transformers/models/encodec/configuration_encodec.py +++ b/src/transformers/models/encodec/configuration_encodec.py @@ -90,6 +90,9 @@ class EncodecConfig(PretrainedConfig): Number of discret codes that make up VQVAE. codebook_dim (`int`, *optional*): Dimension of the codebook vectors. If not defined, uses `hidden_size`. + use_conv_shortcut (`bool`, *optional*, defaults to `True`): + Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False, + an identity function will be used, giving a generic residual connection. Example: @@ -131,6 +134,7 @@ class EncodecConfig(PretrainedConfig): trim_right_ratio=1.0, codebook_size=1024, codebook_dim=None, + use_conv_shortcut=True, **kwargs, ): self.target_bandwidths = target_bandwidths @@ -155,6 +159,7 @@ class EncodecConfig(PretrainedConfig): self.trim_right_ratio = trim_right_ratio self.codebook_size = codebook_size self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size + self.use_conv_shortcut = use_conv_shortcut if self.norm_type not in ["weight_norm", "time_group_norm"]: raise ValueError( diff --git a/src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py b/src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py index cd7ead3d72..3a16a4b7ba 100644 --- a/src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py +++ b/src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py @@ -28,6 +28,7 @@ from transformers import ( # checkpoints downloaded from: # https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th +# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin # https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th @@ -206,7 +207,7 @@ def should_ignore(name, ignore_keys): def recursively_load_weights(orig_dict, hf_model, model_name): unused_weights = [] - if model_name == "encodec_24khz": + if model_name == "encodec_24khz" or "encodec_32khz": MAPPING = MAPPING_24K elif model_name == "encodec_48khz": MAPPING = MAPPING_48K @@ -292,6 +293,15 @@ def convert_checkpoint( if model_name == "encodec_24khz": pass # config is already correct + elif model_name == "encodec_32khz": + config.upsampling_ratios = [8, 5, 4, 4] + config.target_bandwidths = [2.2] + config.num_filters = 64 + config.sampling_rate = 32_000 + config.codebook_size = 2048 + config.use_causal_conv = False + config.normalize = False + config.use_conv_shortcut = False elif model_name == "encodec_48khz": config.upsampling_ratios = [8, 5, 4, 2] config.target_bandwidths = [3.0, 6.0, 12.0, 24.0] @@ -316,6 +326,9 @@ def convert_checkpoint( feature_extractor.save_pretrained(pytorch_dump_folder_path) original_checkpoint = torch.load(checkpoint_path) + if "best_state" in original_checkpoint: + # we might have a training state saved, in which case discard the yaml results and just retain the weights + original_checkpoint = original_checkpoint["best_state"] recursively_load_weights(original_checkpoint, model, model_name) model.save_pretrained(pytorch_dump_folder_path) @@ -331,7 +344,7 @@ if __name__ == "__main__": "--model", default="encodec_24khz", type=str, - help="The model to convert. Should be one of 'encodec_24khz', 'encodec_48khz'.", + help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.", ) parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index ad1f6a0ee8..697fb3c94f 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -259,7 +259,10 @@ class EncodecResnetBlock(nn.Module): block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)] self.block = nn.ModuleList(block) - self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) + if config.use_conv_shortcut: + self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) + else: + self.shortcut = nn.Identity() def forward(self, hidden_states): residual = hidden_states diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py index 23b2114a5d..398da6f5d0 100644 --- a/tests/models/encodec/test_modeling_encodec.py +++ b/tests/models/encodec/test_modeling_encodec.py @@ -385,6 +385,11 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + def test_identity_shortcut(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + config.use_conv_shortcut = False + self.model_tester.create_and_check_model_forward(config, inputs_dict) + def normalize(arr): norm = np.linalg.norm(arr)