🔴 Fix EnCodec internals and integration tests (#39431)

* Fix EnCodec and update integration tests.

* Apply padding mask when normalize is False.

* Update comment of copied function.

* Fix padding mask within modeling.

* Revert padding function.

* Simplify handling of padding_mask.

* Address variable codebook size.

* Add output for padding for consistency with original model, fix docstrings.

* Make `last_frame_pad_length` an int.

* Update example code.

* Improve docstring/comments.

* Shorten expected output.

* Consistent docstring.

* Parameterize tests.

* Properties for derived variables.

* Update expected outputs from GitHub runner.

* Consistent outputs with runner GPUs.
This commit is contained in:
Eric Bezzam
2025-07-23 19:39:27 +02:00
committed by GitHub
parent 7a4e2e7868
commit c5a80dd6c4
4 changed files with 1184 additions and 233 deletions

View File

@@ -47,7 +47,8 @@ Here is a quick example of how to encode and decode an audio using this model:
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
>>> encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
>>> audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
>>> # `encoder_outputs.audio_codes` contains discrete codes
>>> audio_values = model.decode(**encoder_outputs, padding_mask=inputs["padding_mask"])[0]
>>> # or the equivalent with a forward pass
>>> audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
```

View File

@@ -179,14 +179,21 @@ class EncodecConfig(PretrainedConfig):
else:
return max(1, int((1.0 - self.overlap) * self.chunk_length))
@property
def hop_length(self) -> int:
return int(np.prod(self.upsampling_ratios))
@property
def codebook_nbits(self) -> int:
return math.ceil(math.log2(self.codebook_size))
@property
def frame_rate(self) -> int:
hop_length = np.prod(self.upsampling_ratios)
return math.ceil(self.sampling_rate / hop_length)
return math.ceil(self.sampling_rate / self.hop_length)
@property
def num_quantizers(self) -> int:
return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10))
return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * self.codebook_nbits))
__all__ = ["EncodecConfig"]

View File

@@ -19,7 +19,6 @@ from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...modeling_utils import PreTrainedModel
@@ -41,8 +40,8 @@ logger = logging.get_logger(__name__)
@auto_docstring
class EncodecOutput(ModelOutput):
r"""
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
Discret code embeddings computed using `model.encode`.
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
Discrete code embeddings computed using `model.encode`.
audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
Decoded audio values, obtained using the decoder part of Encodec.
"""
@@ -55,14 +54,19 @@ class EncodecOutput(ModelOutput):
@auto_docstring
class EncodecEncoderOutput(ModelOutput):
r"""
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
Discret code embeddings computed using `model.encode`.
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
Discrete code embeddings computed using `model.encode`.
audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
last_frame_pad_length (`int`, *optional*):
The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
encoded frames.
"""
audio_codes: Optional[torch.LongTensor] = None
audio_scales: Optional[torch.FloatTensor] = None
last_frame_pad_length: Optional[int] = None
@dataclass
@@ -499,7 +503,7 @@ class EncodecModel(EncodecPreTrainedModel):
return self.decoder
def _encode_frame(
self, input_values: torch.Tensor, bandwidth: float, padding_mask: int
self, input_values: torch.Tensor, bandwidth: float
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first
@@ -513,11 +517,10 @@ class EncodecModel(EncodecPreTrainedModel):
scale = None
if self.config.normalize:
# if the padding is non zero
input_values = input_values * padding_mask.unsqueeze(1)
mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
input_values = input_values / scale
scale = scale.view(-1, 1)
embeddings = self.encoder(input_values)
codes = self.quantizer.encode(embeddings, bandwidth)
@@ -530,9 +533,18 @@ class EncodecModel(EncodecPreTrainedModel):
padding_mask: Optional[torch.Tensor] = None,
bandwidth: Optional[float] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], EncodecEncoderOutput]:
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor], int], EncodecEncoderOutput]:
"""
Encodes the input audio waveform into discrete codes.
Encodes the input audio waveform into discrete codes of shape
`(nb_frames, batch_size, nb_quantizers, frame_len)`.
- `nb_frames=1` if `self.config.chunk_length=None` (as the encoder is applied on the full audio), which is the
case for the 24kHz model. Otherwise, `nb_frames=ceil(input_length/self.config.chunk_stride)`, which is the case
for the 48kHz model.
- `frame_len` is the length of each frame, which is equal to `ceil(input_length/self.config.hop_length)` if
`self.config.chunk_length=None` (e.g., for the 24kHz model). Otherwise, if `self.config.chunk_length` is
defined, `frame_len=self.config.chunk_length/self.config.hop_length`, e.g., the case for the 48kHz model with
`frame_len=150`.
Args:
input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
@@ -545,9 +557,10 @@ class EncodecModel(EncodecPreTrainedModel):
as bandwidth == 6.0
Returns:
A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
`codebook` of shape `[batch_size, num_codebooks, frames]`.
EncodecEncoderOutput dict or a tuple containing:
- audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*),
- audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*),
- last_frame_pad_length (`int`, *optional*).
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
@@ -572,29 +585,28 @@ class EncodecModel(EncodecPreTrainedModel):
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
else:
padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])
encoded_frames = []
scales = []
step = chunk_length - stride
if (input_length % stride) - step != 0:
raise ValueError(
"The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
)
for offset in range(0, input_length - step, stride):
for offset in range(0, input_length, stride):
mask = padding_mask[..., offset : offset + chunk_length].bool()
frame = input_values[:, :, offset : offset + chunk_length]
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
frame = mask * input_values[..., offset : offset + chunk_length]
encoded_frame, scale = self._encode_frame(frame, bandwidth)
encoded_frames.append(encoded_frame)
scales.append(scale)
# pad last frame (if necessary) to be able to apply `torch.stack`
last_frame_pad_length = encoded_frames[0].shape[-1] - encoded_frames[-1].shape[-1]
if last_frame_pad_length > 0:
last_frame = nn.functional.pad(encoded_frames[-1], (0, last_frame_pad_length), value=0)
encoded_frames[-1] = last_frame
encoded_frames = torch.stack(encoded_frames)
if not return_dict:
return (encoded_frames, scales)
return EncodecEncoderOutput(encoded_frames, scales)
return (encoded_frames, scales, last_frame_pad_length)
return EncodecEncoderOutput(encoded_frames, scales, last_frame_pad_length)
@staticmethod
def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
@@ -657,6 +669,7 @@ class EncodecModel(EncodecPreTrainedModel):
audio_scales: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
last_frame_pad_length: Optional[int] = 0,
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecDecoderOutput]:
"""
Decodes the given frames into an output audio waveform.
@@ -665,14 +678,16 @@ class EncodecModel(EncodecPreTrainedModel):
trimmed.
Args:
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
Discret code embeddings computed using `model.encode`.
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
Discrete code embeddings computed using `model.encode`.
audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
Scaling factor for each `audio_codes` input.
padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
padding_mask (`torch.Tensor` of shape `(channels, sequence_length)`):
Padding mask used to pad the `input_values`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
last_frame_pad_length (`int`, *optional*):
Integer representing the length of the padding in the last frame, which is removed during decoding.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
@@ -681,11 +696,15 @@ class EncodecModel(EncodecPreTrainedModel):
if chunk_length is None:
if len(audio_codes) != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
frame = audio_codes[0]
if last_frame_pad_length > 0:
frame = frame[..., :-last_frame_pad_length]
audio_values = self._decode_frame(frame, audio_scales[0])
else:
decoded_frames = []
for frame, scale in zip(audio_codes, audio_scales):
for i, (frame, scale) in enumerate(zip(audio_codes, audio_scales)):
if i == len(audio_codes) - 1 and last_frame_pad_length > 0:
frame = frame[..., :-last_frame_pad_length]
frames = self._decode_frame(frame, scale)
decoded_frames.append(frames)
@@ -708,6 +727,7 @@ class EncodecModel(EncodecPreTrainedModel):
audio_codes: Optional[torch.LongTensor] = None,
audio_scales: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
last_frame_pad_length: Optional[int] = 0,
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
@@ -731,10 +751,16 @@ class EncodecModel(EncodecPreTrainedModel):
The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
`bandwidth == 6.0`
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
Discret code embeddings computed using `model.encode`.
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
Discrete code embeddings computed using `model.encode`.
audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
Scaling factor for each `audio_codes` input.
return_dict (`bool`, *optional*):
Whether to return outputs as a dict.
last_frame_pad_length (`int`, *optional*):
The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
encoded frames.
Examples:
@@ -759,6 +785,9 @@ class EncodecModel(EncodecPreTrainedModel):
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
else:
# ensure that channel dimension is present
padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])
if audio_codes is not None and audio_scales is None:
raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")
@@ -767,9 +796,17 @@ class EncodecModel(EncodecPreTrainedModel):
raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")
if audio_scales is None and audio_codes is None:
audio_codes, audio_scales = self.encode(input_values, padding_mask, bandwidth, False)
audio_codes, audio_scales, last_frame_pad_length = self.encode(
input_values, padding_mask, bandwidth, False
)
audio_values = self.decode(audio_codes, audio_scales, padding_mask, return_dict=return_dict)[0]
audio_values = self.decode(
audio_codes,
audio_scales,
padding_mask,
return_dict=return_dict,
last_frame_pad_length=last_frame_pad_length,
)[0]
if not return_dict:
return (audio_codes, audio_values)

File diff suppressed because it is too large Load Diff