Generate: TF can now accept custom logits processors (#21454)
This commit is contained in:
@@ -1797,12 +1797,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
if is_torch_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
|
||||
"LogitsProcessorList": LogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
|
||||
"create_tensor_fn": torch.tensor,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_diverse_beam_search(self):
|
||||
# PT-only test: TF doesn't have a diverse beam search implementation
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
||||
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
|
||||
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
|
||||
@@ -1836,6 +1839,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_greedy(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
@@ -1862,6 +1866,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_sample(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
@@ -1888,6 +1893,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_beam_search(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
@@ -1918,6 +1924,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_group_beam_search(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria & group beam search
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
@@ -1952,6 +1959,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_max_length_warning_if_different(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
@@ -2035,6 +2043,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_custom_stopping_criteria_overload_error(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
@@ -2048,6 +2057,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
|
||||
|
||||
def test_custom_stopping_criteria(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
@@ -2070,7 +2080,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_stop_sequence_stopping_criteria(self):
|
||||
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
prompt = """Hello I believe in"""
|
||||
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
|
||||
output = generator(prompt)
|
||||
@@ -2088,23 +2098,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
output = generator(prompt, stop_sequence=" number")
|
||||
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
|
||||
|
||||
def test_custom_logits_processor(self):
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
|
||||
# it should not be allowed to both define `min_length` via config and `logits_processor` list
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
bart_model.config.min_length = None
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
def test_max_new_tokens_encoder_decoder(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
Reference in New Issue
Block a user