Generate: TF can now accept custom logits processors (#21454)

This commit is contained in:
Joao Gante
2023-02-06 15:44:47 +00:00
committed by GitHub
parent e215e6ded2
commit 4943331015
5 changed files with 81 additions and 19 deletions

View File

@@ -1797,12 +1797,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
if is_torch_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
"LogitsProcessorList": LogitsProcessorList,
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
"create_tensor_fn": torch.tensor,
"return_tensors": "pt",
}
@slow
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
@@ -1836,6 +1839,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_greedy(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1862,6 +1866,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_sample(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1888,6 +1893,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1918,6 +1924,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_group_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria & group beam search
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -1952,6 +1959,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_warning_if_different(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
@@ -2035,6 +2043,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
@@ -2048,6 +2057,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
def test_custom_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
@@ -2070,7 +2080,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_stop_sequence_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
prompt = """Hello I believe in"""
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
output = generator(prompt)
@@ -2088,23 +2098,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
output = generator(prompt, stop_sequence=" number")
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
def test_custom_logits_processor(self):
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
logits_processor = LogitsProcessorList()
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
# it should not be allowed to both define `min_length` via config and `logits_processor` list
with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor)
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
def test_max_new_tokens_encoder_decoder(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")