Generate: TF can now generate from embeddings in encoder-decoder models (#21475)
This commit is contained in:
@@ -5,11 +5,13 @@ Framework agnostic tests for generate()-related methods.
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.testing_utils import torch_device
|
||||
|
||||
|
||||
class GenerationIntegrationTestsMixin:
|
||||
# To be populated by the child classes
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": None,
|
||||
"AutoModelForSeq2SeqLM": None,
|
||||
"LogitsProcessorList": None,
|
||||
"MinLengthLogitsProcessor": None,
|
||||
@@ -60,3 +62,91 @@ class GenerationIntegrationTestsMixin:
|
||||
|
||||
bart_model.config.min_length = None
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
def test_max_new_tokens_encoder_decoder(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
bart_model = bart_model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 29])
|
||||
|
||||
max_new_tokens = 3
|
||||
bart_model.config.max_length = 20
|
||||
bart_model.config.eos_token_id = None
|
||||
|
||||
# Encoder decoder call
|
||||
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
# 1 BOS + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 4])
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 29 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 32])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
input_ids = gpt2_tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
gpt2_model = gpt2_model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 9])
|
||||
|
||||
max_new_tokens = 3
|
||||
gpt2_model.config.max_length = 20
|
||||
|
||||
# call < 20
|
||||
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
|
||||
# 9 input_ids + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 12])
|
||||
|
||||
# call > 20
|
||||
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
output_sequences = model.generate(inputs_embeds=inputs_embeds)
|
||||
|
||||
# make sure model generated correctly until `max_length`
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
@@ -135,6 +135,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
|
||||
if is_tf_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": TFAutoModelForCausalLM,
|
||||
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
||||
"LogitsProcessorList": TFLogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
|
||||
|
||||
@@ -40,7 +40,6 @@ if is_torch_available():
|
||||
ImageGPTForCausalImageModeling,
|
||||
Speech2TextForConditionalGeneration,
|
||||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
VisionEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
@@ -1792,6 +1791,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
|
||||
if is_torch_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": AutoModelForCausalLM,
|
||||
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
|
||||
"LogitsProcessorList": LogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
|
||||
@@ -2094,182 +2094,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
output = generator(prompt, stop_sequence=" number")
|
||||
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
||||
|
||||
def test_max_new_tokens_encoder_decoder(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 29])
|
||||
|
||||
max_new_tokens = 3
|
||||
bart_model.config.max_length = 20
|
||||
bart_model.config.eos_token_id = None
|
||||
|
||||
# Encoder decoder call
|
||||
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
# 1 BOS + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 4])
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 29 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 32])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
|
||||
input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 56])
|
||||
|
||||
max_new_tokens = 3
|
||||
t5_model.config.max_length = 20
|
||||
t5_model.config.eos_token_id = None
|
||||
|
||||
# Encoder decoder call
|
||||
outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
|
||||
# 1 BOS + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 4])
|
||||
|
||||
# Decoder only call
|
||||
outputs = t5_model.generate(
|
||||
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
|
||||
)
|
||||
# 56 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 59])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 29])
|
||||
|
||||
max_new_tokens = 3
|
||||
bart_model.config.max_length = 20
|
||||
bart_model.config.eos_token_id = None
|
||||
|
||||
# Encoder decoder call
|
||||
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
|
||||
# 1 BOS + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 4])
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(
|
||||
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
|
||||
)
|
||||
# 29 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 32])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
|
||||
article = """Justin Timberlake."""
|
||||
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
|
||||
gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
|
||||
input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 9])
|
||||
|
||||
max_new_tokens = 3
|
||||
gptj_model.config.max_length = 20
|
||||
|
||||
# call < 20
|
||||
outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
# 9 input_ids + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 12])
|
||||
|
||||
# call > 20
|
||||
outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 9])
|
||||
|
||||
max_new_tokens = 3
|
||||
gpt2_model.config.max_length = 20
|
||||
|
||||
# call < 20
|
||||
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
# 9 input_ids + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 12])
|
||||
|
||||
# call > 20
|
||||
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only(self):
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 9])
|
||||
|
||||
max_new_tokens = 3
|
||||
gpt2_model.config.max_length = 20
|
||||
|
||||
# call < 20
|
||||
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
|
||||
# 9 input_ids + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 12])
|
||||
|
||||
# call > 20
|
||||
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
|
||||
torch_device
|
||||
)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
output_sequences = model.generate(inputs_embeds=inputs_embeds)
|
||||
|
||||
# make sure model generated correctly until `max_length`
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
def test_encoder_decoder_generate_attention_mask(self):
|
||||
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
Reference in New Issue
Block a user