Fix DAC slow tests (#34088)

* Fix DAC slow tests and fix decode

* [run-slow] dac
This commit is contained in:
Yoach Lacombe
2024-10-11 14:43:03 +02:00
committed by GitHub
parent f052e94bcc
commit 9dca0c9116
2 changed files with 9 additions and 4 deletions

View File

@@ -641,14 +641,14 @@ class DacModel(DacPreTrainedModel):
@replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
def decode(
self,
quantized_representation: Optional[torch.Tensor],
quantized_representation: Optional[torch.Tensor] = None,
audio_codes: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
):
"""Decode given latent codes and return audio data
Args:
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`):
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
The codebook indices for each codebook, representing the quantized discrete

View File

@@ -458,9 +458,9 @@ class DacIntegrationTest(unittest.TestCase):
expected_rmse = 0.0039
expected_encoder_output_dict = {
"quantized_representation": torch.tensor([0.9807, 2.8212, 5.2514, 2.7241, 1.0426]),
"quantized_representation": torch.tensor([0.6257, 3.1245, 5.2514, 2.3160, 1.5774]),
"audio_codes": torch.tensor([919, 919, 234, 777, 234]),
"projected_latents": torch.tensor([-4.7822, -5.0046, -4.5574, -5.0363, -5.4271]),
"projected_latents": torch.tensor([-4.7841, -5.0063, -4.5595, -5.0372, -5.4280]),
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
@@ -507,6 +507,11 @@ class DacIntegrationTest(unittest.TestCase):
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
input_values_from_codes = model.decode(audio_codes=encoder_outputs.audio_codes)[0]
# make sure decode from audio codes and quantized values give more or less the same results
self.assertTrue(torch.allclose(input_values_from_codes, input_values_dec, atol=1e-5))
# make sure forward and decode gives same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))