CLVP Fixes (#27547)
* fixes * more fixes * style fix * more fix * comments
This commit is contained in:
@@ -81,8 +81,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@@ -107,7 +106,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
v_embed = (v * cos) + (rotate_half(v) * sin)
|
||||
return q_embed, k_embed, v_embed
|
||||
|
||||
|
||||
def _pad_extra_bos_eos_tokens(
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
pad_token_id=0,
|
||||
bos_token_id=255,
|
||||
eos_token_id=0,
|
||||
add_bos_token=True,
|
||||
add_eos_token=True,
|
||||
):
|
||||
"""
|
||||
This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
|
||||
`ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
# add the bos token at the beginning
|
||||
if add_bos_token:
|
||||
input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
|
||||
attention_mask = (
|
||||
torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
|
||||
)
|
||||
|
||||
modified_input_ids = input_ids
|
||||
if add_eos_token:
|
||||
modified_input_ids = torch.zeros(
|
||||
(input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
for i, each_input_id in enumerate(input_ids):
|
||||
# locate where the valid tokens end and then add the eos token
|
||||
if torch.isin(each_input_id, pad_token_id).sum():
|
||||
pos = torch.where(each_input_id == pad_token_id)[0].min()
|
||||
modified_input_ids[i] = torch.concatenate(
|
||||
[each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
|
||||
)
|
||||
else:
|
||||
# if there are no pad tokens present, then add eos to the end
|
||||
modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
|
||||
attention_mask = (
|
||||
torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
|
||||
)
|
||||
|
||||
return modified_input_ids, attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -312,13 +355,18 @@ class ClvpSelfAttention(nn.Module):
|
||||
key_states[..., :rotary_emb_dim],
|
||||
key_states[..., rotary_emb_dim:],
|
||||
)
|
||||
value_rot, value_pass = (
|
||||
value_states[..., :rotary_emb_dim],
|
||||
value_states[..., rotary_emb_dim:],
|
||||
)
|
||||
|
||||
cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)
|
||||
|
||||
# [batch_size, num_heads, seq_length, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
||||
value_states = torch.cat((value_rot, value_pass), dim=-1)
|
||||
|
||||
tgt_len = query_states.shape[2]
|
||||
src_len = key_states.shape[2]
|
||||
@@ -599,16 +647,7 @@ class ClvpConditioningEncoder(nn.Module):
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
|
||||
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
|
||||
input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=self.text_config.bos_token_id)
|
||||
input_ids = torch.nn.functional.pad(input_ids, (0, 1), value=self.text_config.eos_token_id)
|
||||
batch_size, seq_length = input_ids.size()
|
||||
inputs_embeds = self.text_token_embedding(input_ids)
|
||||
# check if we need to update attention mask, if yes then pad it too
|
||||
if attention_mask is not None and attention_mask.shape[1] != seq_length:
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, (1, 0), value=1)
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, (0, 1), value=1)
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
@@ -616,8 +655,18 @@ class ClvpConditioningEncoder(nn.Module):
|
||||
|
||||
# construct attention mask if not given
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=inputs_embeds.device)
|
||||
attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device)
|
||||
|
||||
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
|
||||
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
|
||||
input_ids, attention_mask = _pad_extra_bos_eos_tokens(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
bos_token_id=self.text_config.bos_token_id,
|
||||
eos_token_id=self.text_config.eos_token_id,
|
||||
)
|
||||
|
||||
inputs_embeds = self.text_token_embedding(input_ids)
|
||||
position_ids = attention_mask.cumsum(-1) - 1
|
||||
position_embeds = self.text_position_embedding(position_ids)
|
||||
text_embeds = inputs_embeds + position_embeds
|
||||
@@ -1512,10 +1561,6 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
|
||||
"""
|
||||
decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes
|
||||
speech_ids = speech_ids[:, 1:]
|
||||
if torch.isin(self.speech_decoder_model.config.eos_token_id, speech_ids):
|
||||
speech_ids = torch.nn.functional.pad(
|
||||
speech_ids, pad=(0, 1), value=self.speech_decoder_model.config.eos_token_id
|
||||
)
|
||||
|
||||
stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0)
|
||||
speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
|
||||
@@ -1828,6 +1873,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
|
||||
input_features: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
pad_to_max_mel_tokens: Optional[int] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -1855,6 +1901,11 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
|
||||
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
||||
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
||||
default values, whose documentation should be checked to parameterize generation.
|
||||
pad_to_max_mel_tokens (`int`, *optional*):
|
||||
Pads generated speech_ids to the specified value. This is to implement the same logic from the official
|
||||
repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
|
||||
and to make sure the logits are same.
|
||||
This does not affect generation quality so please don't consider using it since it is less efficient.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of decoder model, text encoder and speech encoder models.
|
||||
|
||||
@@ -1862,6 +1913,17 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
|
||||
`ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when
|
||||
`config.return_dict_in_generate=True`) or a tuple.
|
||||
"""
|
||||
|
||||
# If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error,
|
||||
# because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
|
||||
# properly sample
|
||||
sequence_length = input_ids.shape[-1]
|
||||
if sequence_length > (self.config.decoder_config.max_text_tokens - 3):
|
||||
raise ValueError(
|
||||
f"Maximum sequence length reached! Found input_ids of length {sequence_length}."
|
||||
f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}"
|
||||
)
|
||||
|
||||
if generation_config is None:
|
||||
generation_config = self.generation_config
|
||||
|
||||
@@ -1870,6 +1932,16 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
|
||||
generation_config.validate()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# pad input_ids as specified in the original repo
|
||||
# link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
|
||||
input_ids, attention_mask = _pad_extra_bos_eos_tokens(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
add_bos_token=False,
|
||||
bos_token_id=self.config.text_config.bos_token_id,
|
||||
eos_token_id=self.config.text_config.eos_token_id,
|
||||
)
|
||||
|
||||
conditioning_embeds = self.conditioning_encoder(
|
||||
input_features=input_features,
|
||||
input_ids=input_ids,
|
||||
@@ -1884,6 +1956,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
|
||||
)
|
||||
if isinstance(decoder_outputs, ModelOutput):
|
||||
speech_ids = decoder_outputs.sequences
|
||||
|
||||
# pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
|
||||
# link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
|
||||
if pad_to_max_mel_tokens is not None:
|
||||
padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
|
||||
speech_ids = torch.nn.functional.pad(
|
||||
speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
|
||||
)
|
||||
|
||||
speech_ids = self.fix_speech_decoder_output(speech_ids)
|
||||
|
||||
speech_outputs = self.speech_encoder_model(
|
||||
|
||||
@@ -604,12 +604,7 @@ class ClvpIntegrationTest(unittest.TestCase):
|
||||
text_embeds = self.model.text_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TEXT_EMBEDS = torch.tensor(
|
||||
[ 1.8060e+00, -2.7928e+00, 3.2021e+00, -1.5673e+00, 2.3284e+00, -3.2065e+00, -1.3368e+00, 2.2322e+00,
|
||||
-1.7667e+00, 4.1505e-01, 2.4119e+00, -5.8133e-03, -4.6367e+00, 1.6450e-01, 6.7459e+00, 6.6292e+00,
|
||||
1.1046e+00, 3.6196e+00, -1.0496e+01, 5.4924e+00
|
||||
]
|
||||
)
|
||||
EXPECTED_TEXT_EMBEDS = torch.tensor([1.4798, -2.0005, 2.3902, -0.5042, 1.6401, -2.4135, -1.4800, 3.0118, -2.4422, 1.3266, 2.2339, 1.4761, -4.8983, -1.3592, 6.0251, 6.7364, 2.2576, 3.7229, -10.0436, 4.6676])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(text_embeds[0, :20], EXPECTED_TEXT_EMBEDS, atol=1e-4))
|
||||
@@ -618,11 +613,7 @@ class ClvpIntegrationTest(unittest.TestCase):
|
||||
speech_embeds = self.model.speech_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_SPEECH_EMBEDS = torch.tensor(
|
||||
[ 4.6143, -5.5784, 0.8983, -3.9665, -0.6714, -1.0665, -1.1277, 1.5619, 2.6322, -7.2008, -2.4932, 0.3265,
|
||||
-1.4738, 0.1425, 5.0825, 4.1760, -5.4708, 2.1935, -6.0044, 3.9540
|
||||
]
|
||||
)
|
||||
EXPECTED_SPEECH_EMBEDS = torch.tensor([3.1202, -3.1183, -1.4264, -6.1339, 1.8885, -0.1983, 0.9461, -1.7414, 0.3320, -3.8400, -1.5715, 1.5096, -1.7576, 0.2387, 4.9758, 5.8450, -6.2534, 2.8587, -5.5816, 4.7821])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(speech_embeds[0, :20], EXPECTED_SPEECH_EMBEDS, atol=1e-4))
|
||||
@@ -635,8 +626,10 @@ class ClvpIntegrationTest(unittest.TestCase):
|
||||
num_beams=4,
|
||||
num_return_sequences=4,
|
||||
max_new_tokens=10,
|
||||
).speech_ids.cpu()
|
||||
)
|
||||
|
||||
EXPECTED_OUTPUTS = torch.tensor([[1953, 1080, 612], [1953, 1953, 612], [1953, 612, 716]])
|
||||
EXPECTED_SPEECH_IDS = torch.tensor([[1953, 1080, 612], [1953, 612, 493], [1953, 612, 716]])
|
||||
EXPECTED_SIMILARITY_SCORES = torch.tensor([[14.7660, 14.4569, 13.6472, 13.5683]])
|
||||
|
||||
self.assertTrue(torch.allclose(full_model_output[-3:, -3:], EXPECTED_OUTPUTS))
|
||||
self.assertTrue(torch.allclose(full_model_output.speech_ids.cpu()[-3:, -3:], EXPECTED_SPEECH_IDS))
|
||||
self.assertTrue(torch.allclose(full_model_output.logits_per_text.cpu(), EXPECTED_SIMILARITY_SCORES))
|
||||
|
||||
Reference in New Issue
Block a user