Generate: TF can now accept custom logits processors (#21454)

This commit is contained in:
Joao Gante
2023-02-06 15:44:47 +00:00
committed by GitHub
parent e215e6ded2
commit 4943331015
5 changed files with 81 additions and 19 deletions

View File

@@ -12,6 +12,8 @@ class GenerationIntegrationTestsMixin:
# To be populated by the child classes
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": None,
"LogitsProcessorList": None,
"MinLengthLogitsProcessor": None,
"create_tensor_fn": None,
"return_tensors": None,
}
@@ -39,3 +41,23 @@ class GenerationIntegrationTestsMixin:
# however, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
model.generate(input_ids, **valid_model_kwargs)
def test_custom_logits_processor(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"]
min_length_logits_processor_cls = self.framework_dependent_parameters["MinLengthLogitsProcessor"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1)
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
logits_processor = logits_processor_list_cls()
logits_processor.append(min_length_logits_processor_cls(min_length=10, eos_token_id=0))
# it should not be allowed to both define `min_length` via config and `logits_processor` list
with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor)
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)

View File

@@ -25,7 +25,13 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
from transformers import (
TFAutoModelForCausalLM,
TFAutoModelForSeq2SeqLM,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
tf_top_k_top_p_filtering,
)
@require_tf
@@ -132,6 +138,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
if is_tf_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
"LogitsProcessorList": TFLogitsProcessorList,
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
"create_tensor_fn": tf.convert_to_tensor,
"return_tensors": "tf",
}

View File

@@ -1797,12 +1797,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
if is_torch_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
"LogitsProcessorList": LogitsProcessorList,
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
"create_tensor_fn": torch.tensor,
"return_tensors": "pt",
}
@slow
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
@@ -1836,6 +1839,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_greedy(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1862,6 +1866,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_sample(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1888,6 +1893,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1918,6 +1924,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_group_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria & group beam search
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1952,6 +1959,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_warning_if_different(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -2035,6 +2043,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
@@ -2048,6 +2057,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
def test_custom_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
@@ -2070,7 +2080,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_stop_sequence_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
prompt = """Hello I believe in"""
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
output = generator(prompt)
@@ -2088,23 +2098,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
output = generator(prompt, stop_sequence=" number")
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
def test_custom_logits_processor(self):
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
logits_processor = LogitsProcessorList()
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
# it should not be allowed to both define `min_length` via config and `logits_processor` list
with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor)
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
def test_max_new_tokens_encoder_decoder(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")