Fix contrastive search to correctly handle input with padding (#33507)
* fix: handle padding in contrastive search for decoder-only models * fix: handle padding in contrastive search for encoder-decoder models * tests: move padding contrastive test to test_util, add t5 test * fix: handle if model_kwargs["decoder_attention_mask"] is None * refactor: improve padding input contrastive search generation tests * chore: _ranking_fast to use LongTensor for cosine_matrix_mask
This commit is contained in:
@@ -44,6 +44,7 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
@@ -59,6 +60,7 @@ if is_torch_available():
|
||||
GPT2Tokenizer,
|
||||
ImageGPTForCausalImageModeling,
|
||||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||
from transformers.generation import (
|
||||
@@ -3644,6 +3646,139 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
value_cache_1 = results.past_key_values.value_cache[1]
|
||||
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||
|
||||
@slow
|
||||
def test_padding_input_contrastive_search_gpt2(self):
|
||||
# Load the pre-trained GPT-2 model and tokenizer
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
|
||||
model.to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
|
||||
|
||||
# Set the tokenizer to left-pad the sequences
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Define the PAD token as the EOS token
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
||||
|
||||
# Define the input prompt
|
||||
prompt_text = "The whispered legends of the haunted mansion spoke"
|
||||
|
||||
# Tokenize the input prompt
|
||||
encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
|
||||
input_ids = encoded_prompt.input_ids.to(torch_device)
|
||||
attention_mask = encoded_prompt.attention_mask.to(torch_device)
|
||||
|
||||
# Define the contrastive search params
|
||||
penalty_alpha = 0.6
|
||||
top_k = 4
|
||||
|
||||
# Define the padding length to add to the input IDs and attention mask
|
||||
padding_length = 10
|
||||
|
||||
# Generate text without padding
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
penalty_alpha=penalty_alpha,
|
||||
top_k=top_k,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# Pad the input IDs and attention mask on the left
|
||||
padded_input_ids = F.pad(
|
||||
input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
|
||||
)
|
||||
padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)
|
||||
|
||||
# Generate text with padded inputs
|
||||
outputs_with_padding = model.generate(
|
||||
input_ids=padded_input_ids,
|
||||
attention_mask=padded_attention_mask,
|
||||
do_sample=False,
|
||||
penalty_alpha=penalty_alpha,
|
||||
top_k=top_k,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
|
||||
|
||||
# Assert that the generated texts are identical for padded and non-padded inputs
|
||||
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
|
||||
self.assertEqual(
|
||||
generated_text_with_padding,
|
||||
'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
|
||||
'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
|
||||
'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_padding_input_contrastive_search_t5(self):
|
||||
# Load the pre-trained T5 model and tokenizer
|
||||
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
model.to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)
|
||||
|
||||
# Define the input prompt
|
||||
prompt_text = "translate English to German: I need to finish this task before the end of the day."
|
||||
|
||||
# Tokenize the input prompt
|
||||
encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
|
||||
input_ids = encoded_prompt.input_ids.to(torch_device)
|
||||
attention_mask = encoded_prompt.attention_mask.to(torch_device)
|
||||
|
||||
# Define the decoder prompt
|
||||
decoder_prompt_text = "Ich muss diese Aufgabe"
|
||||
encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||
decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
|
||||
decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)
|
||||
|
||||
# Define the contrastive search params
|
||||
penalty_alpha = 0.6
|
||||
top_k = 4
|
||||
|
||||
# Generate text without padding
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
do_sample=False,
|
||||
penalty_alpha=penalty_alpha,
|
||||
top_k=top_k,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# Define the padding length to add to the input IDs and attention mask
|
||||
padding_length = 10
|
||||
|
||||
# Pad the decoder input IDs and attention mask on the left
|
||||
padded_decoder_input_ids = F.pad(
|
||||
decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
|
||||
)
|
||||
padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
|
||||
# Since the decoder_start_token_id is the same as the pad_token_id,
|
||||
# the last padded token represents the decoder start token.
|
||||
# Set the attention mask for the decoder_start_token_id to True (1).
|
||||
padded_decoder_attention_mask[:, padding_length - 1] = 1
|
||||
# Generate text with padded inputs
|
||||
outputs_with_padding = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=padded_decoder_input_ids,
|
||||
decoder_attention_mask=padded_decoder_attention_mask,
|
||||
do_sample=False,
|
||||
penalty_alpha=penalty_alpha,
|
||||
top_k=top_k,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
|
||||
|
||||
# Assert that the generated texts are identical for padded and non-padded inputs
|
||||
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
|
||||
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user