🔴 Fix EnCodec internals and integration tests (#39431)
* EnCodec fixes and update integration tests. * Apply padding mask when normalize is False. * Update comment of copied function. * Fix padding mask within modeling. * Revert padding function. * Simplify handling of padding_mask. * Address variable codebook size. * Add output for padding for consistency with original model, fix docstrings. * last_frame_pad_length as int * Update example code. * Improve docstring/comments. * Shorten expected output. * Consistent docstring. * Parameterize tests. * Properties for derived variables. * Update expected outputs from GitHub runner. * Consistent outputs with runner GPUs.
This commit is contained in:
@@ -47,7 +47,8 @@ Here is a quick example of how to encode and decode an audio using this model:
|
|||||||
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
|
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
|
||||||
|
|
||||||
>>> encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
|
>>> encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
|
||||||
>>> audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
|
>>> # `encoder_outputs.audio_codes` contains discrete codes
|
||||||
|
>>> audio_values = model.decode(**encoder_outputs, padding_mask=inputs["padding_mask"])[0]
|
||||||
>>> # or the equivalent with a forward pass
|
>>> # or the equivalent with a forward pass
|
||||||
>>> audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
|
>>> audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -179,14 +179,21 @@ class EncodecConfig(PretrainedConfig):
|
|||||||
else:
|
else:
|
||||||
return max(1, int((1.0 - self.overlap) * self.chunk_length))
|
return max(1, int((1.0 - self.overlap) * self.chunk_length))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hop_length(self) -> int:
|
||||||
|
return int(np.prod(self.upsampling_ratios))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def codebook_nbits(self) -> int:
|
||||||
|
return math.ceil(math.log2(self.codebook_size))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def frame_rate(self) -> int:
|
def frame_rate(self) -> int:
|
||||||
hop_length = np.prod(self.upsampling_ratios)
|
return math.ceil(self.sampling_rate / self.hop_length)
|
||||||
return math.ceil(self.sampling_rate / hop_length)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_quantizers(self) -> int:
|
def num_quantizers(self) -> int:
|
||||||
return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10))
|
return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * self.codebook_nbits))
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["EncodecConfig"]
|
__all__ = ["EncodecConfig"]
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from dataclasses import dataclass
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
@@ -41,8 +40,8 @@ logger = logging.get_logger(__name__)
|
|||||||
@auto_docstring
|
@auto_docstring
|
||||||
class EncodecOutput(ModelOutput):
|
class EncodecOutput(ModelOutput):
|
||||||
r"""
|
r"""
|
||||||
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
|
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
|
||||||
Discret code embeddings computed using `model.encode`.
|
Discrete code embeddings computed using `model.encode`.
|
||||||
audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
|
audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
|
||||||
Decoded audio values, obtained using the decoder part of Encodec.
|
Decoded audio values, obtained using the decoder part of Encodec.
|
||||||
"""
|
"""
|
||||||
@@ -55,14 +54,19 @@ class EncodecOutput(ModelOutput):
|
|||||||
@auto_docstring
|
@auto_docstring
|
||||||
class EncodecEncoderOutput(ModelOutput):
|
class EncodecEncoderOutput(ModelOutput):
|
||||||
r"""
|
r"""
|
||||||
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
|
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
|
||||||
Discret code embeddings computed using `model.encode`.
|
Discrete code embeddings computed using `model.encode`.
|
||||||
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
|
audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
|
||||||
Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
|
Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
|
||||||
|
last_frame_pad_length (`int`, *optional*):
|
||||||
|
The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
|
||||||
|
outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
|
||||||
|
encoded frames.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio_codes: Optional[torch.LongTensor] = None
|
audio_codes: Optional[torch.LongTensor] = None
|
||||||
audio_scales: Optional[torch.FloatTensor] = None
|
audio_scales: Optional[torch.FloatTensor] = None
|
||||||
|
last_frame_pad_length: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -499,7 +503,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
def _encode_frame(
|
def _encode_frame(
|
||||||
self, input_values: torch.Tensor, bandwidth: float, padding_mask: int
|
self, input_values: torch.Tensor, bandwidth: float
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first
|
Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first
|
||||||
@@ -513,11 +517,10 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
|
|
||||||
scale = None
|
scale = None
|
||||||
if self.config.normalize:
|
if self.config.normalize:
|
||||||
# if the padding is non zero
|
|
||||||
input_values = input_values * padding_mask.unsqueeze(1)
|
|
||||||
mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
|
mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
|
||||||
scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
|
scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
|
||||||
input_values = input_values / scale
|
input_values = input_values / scale
|
||||||
|
scale = scale.view(-1, 1)
|
||||||
|
|
||||||
embeddings = self.encoder(input_values)
|
embeddings = self.encoder(input_values)
|
||||||
codes = self.quantizer.encode(embeddings, bandwidth)
|
codes = self.quantizer.encode(embeddings, bandwidth)
|
||||||
@@ -530,9 +533,18 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
padding_mask: Optional[torch.Tensor] = None,
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
bandwidth: Optional[float] = None,
|
bandwidth: Optional[float] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], EncodecEncoderOutput]:
|
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor], int], EncodecEncoderOutput]:
|
||||||
"""
|
"""
|
||||||
Encodes the input audio waveform into discrete codes.
|
Encodes the input audio waveform into discrete codes of shape
|
||||||
|
`(nb_frames, batch_size, nb_quantizers, frame_len)`.
|
||||||
|
|
||||||
|
- `nb_frames=1` if `self.config.chunk_length=None` (as the encoder is applied on the full audio), which is the
|
||||||
|
case for the 24kHz model. Otherwise, `nb_frames=ceil(input_length/self.config.chunk_stride)`, which is the case
|
||||||
|
for the 48kHz model.
|
||||||
|
- `frame_len` is the length of each frame, which is equal to `ceil(input_length/self.config.hop_length)` if
|
||||||
|
`self.config.chunk_length=None` (e.g., for the 24kHz model). Otherwise, if `self.config.chunk_length` is
|
||||||
|
defined, `frame_len=self.config.chunk_length/self.config.hop_length`, e.g., the case for the 48kHz model with
|
||||||
|
`frame_len=150`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
|
input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
|
||||||
@@ -545,9 +557,10 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
as bandwidth == 6.0
|
as bandwidth == 6.0
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
|
EncodecEncoderOutput dict or a tuple containing:
|
||||||
factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
|
- audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*),
|
||||||
`codebook` of shape `[batch_size, num_codebooks, frames]`.
|
- audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*),
|
||||||
|
- last_frame_pad_length (`int`, *optional*).
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
@@ -572,29 +585,28 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
|
|
||||||
if padding_mask is None:
|
if padding_mask is None:
|
||||||
padding_mask = torch.ones_like(input_values).bool()
|
padding_mask = torch.ones_like(input_values).bool()
|
||||||
|
else:
|
||||||
|
padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])
|
||||||
|
|
||||||
encoded_frames = []
|
encoded_frames = []
|
||||||
scales = []
|
scales = []
|
||||||
|
for offset in range(0, input_length, stride):
|
||||||
step = chunk_length - stride
|
|
||||||
if (input_length % stride) - step != 0:
|
|
||||||
raise ValueError(
|
|
||||||
"The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
|
|
||||||
)
|
|
||||||
|
|
||||||
for offset in range(0, input_length - step, stride):
|
|
||||||
mask = padding_mask[..., offset : offset + chunk_length].bool()
|
mask = padding_mask[..., offset : offset + chunk_length].bool()
|
||||||
frame = input_values[:, :, offset : offset + chunk_length]
|
frame = mask * input_values[..., offset : offset + chunk_length]
|
||||||
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
|
encoded_frame, scale = self._encode_frame(frame, bandwidth)
|
||||||
encoded_frames.append(encoded_frame)
|
encoded_frames.append(encoded_frame)
|
||||||
scales.append(scale)
|
scales.append(scale)
|
||||||
|
|
||||||
|
# pad last frame (if necessary) to be able to apply `torch.stack`
|
||||||
|
last_frame_pad_length = encoded_frames[0].shape[-1] - encoded_frames[-1].shape[-1]
|
||||||
|
if last_frame_pad_length > 0:
|
||||||
|
last_frame = nn.functional.pad(encoded_frames[-1], (0, last_frame_pad_length), value=0)
|
||||||
|
encoded_frames[-1] = last_frame
|
||||||
encoded_frames = torch.stack(encoded_frames)
|
encoded_frames = torch.stack(encoded_frames)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (encoded_frames, scales)
|
return (encoded_frames, scales, last_frame_pad_length)
|
||||||
|
return EncodecEncoderOutput(encoded_frames, scales, last_frame_pad_length)
|
||||||
return EncodecEncoderOutput(encoded_frames, scales)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
|
def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
|
||||||
@@ -657,6 +669,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
audio_scales: torch.Tensor,
|
audio_scales: torch.Tensor,
|
||||||
padding_mask: Optional[torch.Tensor] = None,
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
last_frame_pad_length: Optional[int] = 0,
|
||||||
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecDecoderOutput]:
|
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecDecoderOutput]:
|
||||||
"""
|
"""
|
||||||
Decodes the given frames into an output audio waveform.
|
Decodes the given frames into an output audio waveform.
|
||||||
@@ -665,14 +678,16 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
trimmed.
|
trimmed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
|
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
|
||||||
Discret code embeddings computed using `model.encode`.
|
Discrete code embeddings computed using `model.encode`.
|
||||||
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
|
audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
|
||||||
Scaling factor for each `audio_codes` input.
|
Scaling factor for each `audio_codes` input.
|
||||||
padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
|
padding_mask (`torch.Tensor` of shape `(channels, sequence_length)`):
|
||||||
Padding mask used to pad the `input_values`.
|
Padding mask used to pad the `input_values`.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
last_frame_pad_length (`int`, *optional*):
|
||||||
|
Integer representing the length of the padding in the last frame, which is removed during decoding.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
@@ -681,11 +696,15 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
if chunk_length is None:
|
if chunk_length is None:
|
||||||
if len(audio_codes) != 1:
|
if len(audio_codes) != 1:
|
||||||
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
|
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
|
||||||
audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
|
frame = audio_codes[0]
|
||||||
|
if last_frame_pad_length > 0:
|
||||||
|
frame = frame[..., :-last_frame_pad_length]
|
||||||
|
audio_values = self._decode_frame(frame, audio_scales[0])
|
||||||
else:
|
else:
|
||||||
decoded_frames = []
|
decoded_frames = []
|
||||||
|
for i, (frame, scale) in enumerate(zip(audio_codes, audio_scales)):
|
||||||
for frame, scale in zip(audio_codes, audio_scales):
|
if i == len(audio_codes) - 1 and last_frame_pad_length > 0:
|
||||||
|
frame = frame[..., :-last_frame_pad_length]
|
||||||
frames = self._decode_frame(frame, scale)
|
frames = self._decode_frame(frame, scale)
|
||||||
decoded_frames.append(frames)
|
decoded_frames.append(frames)
|
||||||
|
|
||||||
@@ -708,6 +727,7 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
audio_codes: Optional[torch.LongTensor] = None,
|
audio_codes: Optional[torch.LongTensor] = None,
|
||||||
audio_scales: Optional[torch.Tensor] = None,
|
audio_scales: Optional[torch.Tensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
last_frame_pad_length: Optional[int] = 0,
|
||||||
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
|
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
|
||||||
r"""
|
r"""
|
||||||
input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
|
input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
|
||||||
@@ -731,10 +751,16 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
|
The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
|
||||||
bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
|
bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
|
||||||
`bandwidth == 6.0`
|
`bandwidth == 6.0`
|
||||||
audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
|
audio_codes (`torch.LongTensor` of shape `(nb_frames, batch_size, nb_quantizers, frame_len)`, *optional*):
|
||||||
Discret code embeddings computed using `model.encode`.
|
Discrete code embeddings computed using `model.encode`.
|
||||||
audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
|
audio_scales (list of length `nb_frames` of `torch.Tensor` of shape `(batch_size, 1)`, *optional*):
|
||||||
Scaling factor for each `audio_codes` input.
|
Scaling factor for each `audio_codes` input.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether to return outputs as a dict.
|
||||||
|
last_frame_pad_length (`int`, *optional*):
|
||||||
|
The length of the padding in the last frame, if any. This is used to ensure that the encoded frames can be
|
||||||
|
outputted as a tensor. This value should be passed during decoding to ensure padding is removed from the
|
||||||
|
encoded frames.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -759,6 +785,9 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
|
|
||||||
if padding_mask is None:
|
if padding_mask is None:
|
||||||
padding_mask = torch.ones_like(input_values).bool()
|
padding_mask = torch.ones_like(input_values).bool()
|
||||||
|
else:
|
||||||
|
# ensure that channel dimension is present
|
||||||
|
padding_mask = padding_mask.view(padding_mask.shape[0], -1, padding_mask.shape[-1])
|
||||||
|
|
||||||
if audio_codes is not None and audio_scales is None:
|
if audio_codes is not None and audio_scales is None:
|
||||||
raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")
|
raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")
|
||||||
@@ -767,9 +796,17 @@ class EncodecModel(EncodecPreTrainedModel):
|
|||||||
raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")
|
raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")
|
||||||
|
|
||||||
if audio_scales is None and audio_codes is None:
|
if audio_scales is None and audio_codes is None:
|
||||||
audio_codes, audio_scales = self.encode(input_values, padding_mask, bandwidth, False)
|
audio_codes, audio_scales, last_frame_pad_length = self.encode(
|
||||||
|
input_values, padding_mask, bandwidth, False
|
||||||
|
)
|
||||||
|
|
||||||
audio_values = self.decode(audio_codes, audio_scales, padding_mask, return_dict=return_dict)[0]
|
audio_values = self.decode(
|
||||||
|
audio_codes,
|
||||||
|
audio_scales,
|
||||||
|
padding_mask,
|
||||||
|
return_dict=return_dict,
|
||||||
|
last_frame_pad_length=last_frame_pad_length,
|
||||||
|
)[0]
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (audio_codes, audio_values)
|
return (audio_codes, audio_values)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user