Generate: Relaxed max_length and max_new_tokens coexistence (#21347)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -2178,10 +2178,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
@@ -2212,12 +2208,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
t5_model.generate(
|
||||
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
|
||||
)
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
@@ -2250,12 +2240,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(
|
||||
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
|
||||
)
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
|
||||
article = """Justin Timberlake."""
|
||||
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
|
||||
@@ -2279,10 +2263,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -2306,10 +2286,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
|
||||
|
||||
def test_max_new_tokens_decoder_only(self):
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -2333,10 +2309,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
Reference in New Issue
Block a user