From dfbd209c2586e9eb4f55d1c4f5e6a9e0cbca8f60 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Tue, 28 Nov 2023 22:10:01 +0530 Subject: [PATCH] CLVP Fixes (#27547) * fixes * more fixes * style fix * more fix * comments --- src/transformers/models/clvp/modeling_clvp.py | 117 +++++++++++++++--- tests/models/clvp/test_modeling_clvp.py | 21 ++-- 2 files changed, 106 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index db2bbe3f00..64c6927e4a 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -81,8 +81,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -107,7 +106,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + v_embed = (v * cos) + (rotate_half(v) * sin) + return q_embed, k_embed, v_embed + + +def _pad_extra_bos_eos_tokens( + input_ids, + attention_mask=None, + pad_token_id=0, + bos_token_id=255, + eos_token_id=0, + add_bos_token=True, + add_eos_token=True, +): + """ + This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in + `ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`. + """ + + # add the bos token at the beginning + if add_bos_token: + input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id) + attention_mask = ( + torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask + ) + + modified_input_ids = input_ids + if add_eos_token: + modified_input_ids = torch.zeros( + (input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device + ) + for i, each_input_id in enumerate(input_ids): + # locate where the valid tokens end and then add the eos token + if torch.isin(each_input_id, pad_token_id).sum(): + pos = torch.where(each_input_id == pad_token_id)[0].min() + modified_input_ids[i] = torch.concatenate( + [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]] + ) + else: + # if there are no pad tokens present, then add eos to the end + modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id) + attention_mask = ( + torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask + ) + + return modified_input_ids, attention_mask @dataclass @@ -312,13 +355,18 @@ class ClvpSelfAttention(nn.Module): key_states[..., :rotary_emb_dim], key_states[..., rotary_emb_dim:], ) + value_rot, value_pass = ( + value_states[..., :rotary_emb_dim], + value_states[..., rotary_emb_dim:], + ) cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) + query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids) # [batch_size, num_heads, seq_length, head_dim] query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) + value_states = torch.cat((value_rot, value_pass), dim=-1) tgt_len = query_states.shape[2] src_len = key_states.shape[2] @@ -599,16 +647,7 @@ class ClvpConditioningEncoder(nn.Module): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple - # This logic is specific to ClvpConditioningEncoder and not used by other modules. - input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=self.text_config.bos_token_id) - input_ids = torch.nn.functional.pad(input_ids, (0, 1), value=self.text_config.eos_token_id) batch_size, seq_length = input_ids.size() - inputs_embeds = self.text_token_embedding(input_ids) - # check if we need to update attention mask, if yes then pad it too - if attention_mask is not None and attention_mask.shape[1] != seq_length: - attention_mask = torch.nn.functional.pad(attention_mask, (1, 0), value=1) - attention_mask = torch.nn.functional.pad(attention_mask, (0, 1), value=1) elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.size()[:-1] else: @@ -616,8 +655,18 @@ class ClvpConditioningEncoder(nn.Module): # construct attention mask if not given if attention_mask is None: - attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=inputs_embeds.device) + attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device) + # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple + # This logic is specific to ClvpConditioningEncoder and not used by other modules. + input_ids, attention_mask = _pad_extra_bos_eos_tokens( + input_ids, + attention_mask, + bos_token_id=self.text_config.bos_token_id, + eos_token_id=self.text_config.eos_token_id, + ) + + inputs_embeds = self.text_token_embedding(input_ids) position_ids = attention_mask.cumsum(-1) - 1 position_embeds = self.text_position_embedding(position_ids) text_embeds = inputs_embeds + position_embeds @@ -1512,10 +1561,6 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): """ decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes speech_ids = speech_ids[:, 1:] - if torch.isin(self.speech_decoder_model.config.eos_token_id, speech_ids): - speech_ids = torch.nn.functional.pad( - speech_ids, pad=(0, 1), value=self.speech_decoder_model.config.eos_token_id - ) stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0) speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0]) @@ -1828,6 +1873,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): input_features: torch.FloatTensor = None, attention_mask: Optional[torch.LongTensor] = None, generation_config: Optional[GenerationConfig] = None, + pad_to_max_mel_tokens: Optional[int] = None, output_hidden_states: Optional[bool] = None, **kwargs, ): @@ -1855,6 +1901,11 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. + pad_to_max_mel_tokens (`int`, *optional*): + Pads generated speech_ids to the specified value. This is to implement the same logic from the official + repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430 + and to make sure the logits are same. + This does not affect generation quality so please don't consider using it since it is less efficient. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of decoder model, text encoder and speech encoder models. @@ -1862,6 +1913,17 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): `ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a tuple. """ + + # If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error, + # because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to + # properly sample + sequence_length = input_ids.shape[-1] + if sequence_length > (self.config.decoder_config.max_text_tokens - 3): + raise ValueError( + f"Maximum sequence length reached! Found input_ids of length {sequence_length}." + f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}" + ) + if generation_config is None: generation_config = self.generation_config @@ -1870,6 +1932,16 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) + # pad input_ids as specified in the original repo + # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380 + input_ids, attention_mask = _pad_extra_bos_eos_tokens( + input_ids, + attention_mask, + add_bos_token=False, + bos_token_id=self.config.text_config.bos_token_id, + eos_token_id=self.config.text_config.eos_token_id, + ) + conditioning_embeds = self.conditioning_encoder( input_features=input_features, input_ids=input_ids, @@ -1884,6 +1956,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel): ) if isinstance(decoder_outputs, ModelOutput): speech_ids = decoder_outputs.sequences + + # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic + # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430 + if pad_to_max_mel_tokens is not None: + padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1] + speech_ids = torch.nn.functional.pad( + speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id + ) + speech_ids = self.fix_speech_decoder_output(speech_ids) speech_outputs = self.speech_encoder_model( diff --git a/tests/models/clvp/test_modeling_clvp.py b/tests/models/clvp/test_modeling_clvp.py index 3ebe5fe357..e27d9e08eb 100644 --- a/tests/models/clvp/test_modeling_clvp.py +++ b/tests/models/clvp/test_modeling_clvp.py @@ -604,12 +604,7 @@ class ClvpIntegrationTest(unittest.TestCase): text_embeds = self.model.text_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu() # fmt: off - EXPECTED_TEXT_EMBEDS = torch.tensor( - [ 1.8060e+00, -2.7928e+00, 3.2021e+00, -1.5673e+00, 2.3284e+00, -3.2065e+00, -1.3368e+00, 2.2322e+00, - -1.7667e+00, 4.1505e-01, 2.4119e+00, -5.8133e-03, -4.6367e+00, 1.6450e-01, 6.7459e+00, 6.6292e+00, - 1.1046e+00, 3.6196e+00, -1.0496e+01, 5.4924e+00 - ] - ) + EXPECTED_TEXT_EMBEDS = torch.tensor([1.4798, -2.0005, 2.3902, -0.5042, 1.6401, -2.4135, -1.4800, 3.0118, -2.4422, 1.3266, 2.2339, 1.4761, -4.8983, -1.3592, 6.0251, 6.7364, 2.2576, 3.7229, -10.0436, 4.6676]) # fmt: on self.assertTrue(torch.allclose(text_embeds[0, :20], EXPECTED_TEXT_EMBEDS, atol=1e-4)) @@ -618,11 +613,7 @@ class ClvpIntegrationTest(unittest.TestCase): speech_embeds = self.model.speech_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu() # fmt: off - EXPECTED_SPEECH_EMBEDS = torch.tensor( - [ 4.6143, -5.5784, 0.8983, -3.9665, -0.6714, -1.0665, -1.1277, 1.5619, 2.6322, -7.2008, -2.4932, 0.3265, - -1.4738, 0.1425, 5.0825, 4.1760, -5.4708, 2.1935, -6.0044, 3.9540 - ] - ) + EXPECTED_SPEECH_EMBEDS = torch.tensor([3.1202, -3.1183, -1.4264, -6.1339, 1.8885, -0.1983, 0.9461, -1.7414, 0.3320, -3.8400, -1.5715, 1.5096, -1.7576, 0.2387, 4.9758, 5.8450, -6.2534, 2.8587, -5.5816, 4.7821]) # fmt: on self.assertTrue(torch.allclose(speech_embeds[0, :20], EXPECTED_SPEECH_EMBEDS, atol=1e-4)) @@ -635,8 +626,10 @@ class ClvpIntegrationTest(unittest.TestCase): num_beams=4, num_return_sequences=4, max_new_tokens=10, - ).speech_ids.cpu() + ) - EXPECTED_OUTPUTS = torch.tensor([[1953, 1080, 612], [1953, 1953, 612], [1953, 612, 716]]) + EXPECTED_SPEECH_IDS = torch.tensor([[1953, 1080, 612], [1953, 612, 493], [1953, 612, 716]]) + EXPECTED_SIMILARITY_SCORES = torch.tensor([[14.7660, 14.4569, 13.6472, 13.5683]]) - self.assertTrue(torch.allclose(full_model_output[-3:, -3:], EXPECTED_OUTPUTS)) + self.assertTrue(torch.allclose(full_model_output.speech_ids.cpu()[-3:, -3:], EXPECTED_SPEECH_IDS)) + self.assertTrue(torch.allclose(full_model_output.logits_per_text.cpu(), EXPECTED_SIMILARITY_SCORES))