[Generation] Allow inputs_embeds as an input (#14443)
* up * finalize * finalize * finish * Update src/transformers/generation_utils.py * apply feedback
This commit is contained in:
committed by
GitHub
parent
0490b98877
commit
f25a9332e8
@@ -1388,15 +1388,17 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_max_length_backward_compat_greedy(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
input_ids.shape[0],
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1412,15 +1414,17 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_max_length_backward_compat_sample(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
input_ids.shape[0],
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1436,8 +1440,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_max_length_backward_compat_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
@@ -1447,7 +1453,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
input_ids.shape[0],
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1464,8 +1470,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_max_length_backward_compat_group_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
@@ -1477,7 +1485,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
input_ids.shape[0],
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1496,8 +1504,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_max_length_warning_if_different(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
@@ -1513,7 +1523,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
input_ids.shape[0],
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
@@ -1577,8 +1587,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_beam_search_warning_if_max_length_is_passed(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
batch_size = 1
|
||||
num_beams = 3
|
||||
@@ -1587,6 +1599,9 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
input_ids = input_ids.expand(num_beams, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
|
||||
# pretend decoder_input_ids correspond to first encoder input id
|
||||
decoder_input_ids = input_ids[:, :1]
|
||||
|
||||
stopping_criteria_max_length = 18
|
||||
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
|
||||
|
||||
@@ -1599,7 +1614,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
generated_ids = bart_model.beam_search(
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer,
|
||||
@@ -1613,7 +1628,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
generated_ids_no_max_len = bart_model.beam_search(
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer_no_max_len,
|
||||
@@ -1625,14 +1640,17 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_max_new_tokens_encoder_decoder(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 15])
|
||||
self.assertEqual(list(input_ids.shape), [1, 29])
|
||||
|
||||
max_new_tokens = 3
|
||||
bart_model.config.max_length = 20
|
||||
bart_model.config.eos_token_id = None
|
||||
|
||||
# Encoder decoder call
|
||||
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
@@ -1641,8 +1659,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 15 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 18])
|
||||
# 29 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 32])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
@@ -1680,3 +1698,29 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# max_new_tokens and max_length serve the same purpose and should not be used together.
|
||||
with self.assertWarns(UserWarning):
|
||||
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
|
||||
torch_device
|
||||
)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
output_sequences = model.generate(inputs_embeds=inputs_embeds)
|
||||
|
||||
# make sure model generated correctly until `max_length`
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
def test_decoder_generate_with_inputs_embeds(self):
|
||||
article = """I need input_ids to generate"""
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
# cannot generate from `inputs_embeds` for decoder only
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(inputs_embeds=inputs_embeds)
|
||||
|
||||
Reference in New Issue
Block a user