Fix contrastive search to correctly handle input with padding (#33507)

* fix: handle padding in contrastive search for decoder-only models

* fix: handle padding in contrastive search for encoder-decoder models

* tests: move padding contrastive test to test_util, add t5 test

* fix: handle if model_kwargs["decoder_attention_mask"] is None

* refactor: improve padding input contrastive search generation tests

* chore: _ranking_fast to use LongTensor for cosine_matrix_mask
This commit is contained in:
Duc-Viet Hoang
2024-09-20 22:52:08 +07:00
committed by GitHub
parent c0c6815dc9
commit dc8b6eaeee
2 changed files with 158 additions and 1 deletions

View File

@@ -44,6 +44,7 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
if is_torch_available():
import torch
import torch.nn.functional as F
from transformers import (
AutoModelForCausalLM,
@@ -59,6 +60,7 @@ if is_torch_available():
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
@@ -3644,6 +3646,139 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
@slow
def test_padding_input_contrastive_search_gpt2(self):
# Load the pre-trained GPT-2 model and tokenizer
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
# Set the tokenizer to left-pad the sequences
tokenizer.padding_side = "left"
# Define the PAD token as the EOS token
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id
# Define the input prompt
prompt_text = "The whispered legends of the haunted mansion spoke"
# Tokenize the input prompt
encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
input_ids = encoded_prompt.input_ids.to(torch_device)
attention_mask = encoded_prompt.attention_mask.to(torch_device)
# Define the contrastive search params
penalty_alpha = 0.6
top_k = 4
# Define the padding length to add to the input IDs and attention mask
padding_length = 10
# Generate text without padding
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Pad the input IDs and attention mask on the left
padded_input_ids = F.pad(
input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
)
padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)
# Generate text with padded inputs
outputs_with_padding = model.generate(
input_ids=padded_input_ids,
attention_mask=padded_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
# Assert that the generated texts are identical for padded and non-padded inputs
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(
generated_text_with_padding,
'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
)
@slow
def test_padding_input_contrastive_search_t5(self):
# Load the pre-trained T5 model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)
# Define the input prompt
prompt_text = "translate English to German: I need to finish this task before the end of the day."
# Tokenize the input prompt
encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
input_ids = encoded_prompt.input_ids.to(torch_device)
attention_mask = encoded_prompt.attention_mask.to(torch_device)
# Define the decoder prompt
decoder_prompt_text = "Ich muss diese Aufgabe"
encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)
# Define the contrastive search params
penalty_alpha = 0.6
top_k = 4
# Generate text without padding
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Define the padding length to add to the input IDs and attention mask
padding_length = 10
# Pad the decoder input IDs and attention mask on the left
padded_decoder_input_ids = F.pad(
decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
)
padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
# Since the decoder_start_token_id is the same as the pad_token_id,
# the last padded token represents the decoder start token.
# Set the attention mask for the decoder_start_token_id to True (1).
padded_decoder_attention_mask[:, padding_length - 1] = 1
# Generate text with padded inputs
outputs_with_padding = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=padded_decoder_input_ids,
decoder_attention_mask=padded_decoder_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)
# Assert that the generated texts are identical for padded and non-padded inputs
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
@require_torch
class TokenHealingTestCase(unittest.TestCase):