Fix DAC slow tests (#34088)
* Fix DAC slow tests and fix decode
* [run-slow] dac
This commit is contained in:
@@ -641,14 +641,14 @@ class DacModel(DacPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def decode(
|
||||
self,
|
||||
quantized_representation: Optional[torch.Tensor],
|
||||
quantized_representation: Optional[torch.Tensor] = None,
|
||||
audio_codes: Optional[torch.Tensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Args:
|
||||
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`):
|
||||
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
|
||||
Quantized continuous representation of input.
|
||||
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
|
||||
The codebook indices for each codebook, representing the quantized discrete
|
||||
|
||||
@@ -458,9 +458,9 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
expected_rmse = 0.0039
|
||||
|
||||
expected_encoder_output_dict = {
|
||||
"quantized_representation": torch.tensor([0.9807, 2.8212, 5.2514, 2.7241, 1.0426]),
|
||||
"quantized_representation": torch.tensor([0.6257, 3.1245, 5.2514, 2.3160, 1.5774]),
|
||||
"audio_codes": torch.tensor([919, 919, 234, 777, 234]),
|
||||
"projected_latents": torch.tensor([-4.7822, -5.0046, -4.5574, -5.0363, -5.4271]),
|
||||
"projected_latents": torch.tensor([-4.7841, -5.0063, -4.5595, -5.0372, -5.4280]),
|
||||
}
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
@@ -507,6 +507,11 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
input_values_from_codes = model.decode(audio_codes=encoder_outputs.audio_codes)[0]
|
||||
|
||||
# make sure decoding from audio codes and from quantized values gives more or less the same results
|
||||
self.assertTrue(torch.allclose(input_values_from_codes, input_values_dec, atol=1e-5))
|
||||
|
||||
# make sure forward and decode give the same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user