From 9dca0c91169b298ddf3a748d313e86ebb62cd4b8 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Fri, 11 Oct 2024 14:43:03 +0200 Subject: [PATCH] Fix DAC slow tests (#34088) * Fix DAC slow tests and fix decode * [run-slow] dac --- src/transformers/models/dac/modeling_dac.py | 4 ++-- tests/models/dac/test_modeling_dac.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index 549f98b59d..f465ee77fa 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ b/src/transformers/models/dac/modeling_dac.py @@ -641,14 +641,14 @@ class DacModel(DacPreTrainedModel): @replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC) def decode( self, - quantized_representation: Optional[torch.Tensor], + quantized_representation: Optional[torch.Tensor] = None, audio_codes: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, ): """Decode given latent codes and return audio data Args: - quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`): + quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*): Quantized continuous representation of input. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*): The codebook indices for each codebook, representing the quantized discrete diff --git a/tests/models/dac/test_modeling_dac.py b/tests/models/dac/test_modeling_dac.py index e3b729d2f1..55a17ab1e0 100644 --- a/tests/models/dac/test_modeling_dac.py +++ b/tests/models/dac/test_modeling_dac.py @@ -458,9 +458,9 @@ class DacIntegrationTest(unittest.TestCase): expected_rmse = 0.0039 expected_encoder_output_dict = { - "quantized_representation": torch.tensor([0.9807, 2.8212, 5.2514, 2.7241, 1.0426]), + "quantized_representation": torch.tensor([0.6257, 3.1245, 5.2514, 2.3160, 1.5774]), "audio_codes": torch.tensor([919, 919, 234, 777, 234]), - "projected_latents": torch.tensor([-4.7822, -5.0046, -4.5574, -5.0363, -5.4271]), + "projected_latents": torch.tensor([-4.7841, -5.0063, -4.5595, -5.0372, -5.4280]), } librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") @@ -507,6 +507,11 @@ class DacIntegrationTest(unittest.TestCase): input_values_dec = model.decode(quantized_representation)[0] input_values_enc_dec = model(inputs["input_values"])[1] + input_values_from_codes = model.decode(audio_codes=encoder_outputs.audio_codes)[0] + + # make sure decode from audio codes and quantized values give more or less the same results + self.assertTrue(torch.allclose(input_values_from_codes, input_values_dec, atol=1e-5)) + # make sure forward and decode gives same result self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))