Generate: Relaxed max_length and max_new_tokens coexistence (#21347)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -63,14 +63,12 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
|
|
||||||
max_length (`int`, *optional*, defaults to 20):
|
max_length (`int`, *optional*, defaults to 20):
|
||||||
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
||||||
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in the
|
`max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
|
||||||
prompt.
|
|
||||||
max_new_tokens (`int`, *optional*):
|
max_new_tokens (`int`, *optional*):
|
||||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||||
min_length (`int`, *optional*, defaults to 0):
|
min_length (`int`, *optional*, defaults to 0):
|
||||||
The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
|
The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
|
||||||
`min_new_tokens`. In general, prefer the use of `min_new_tokens`, which ignores the number of tokens in the
|
`min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
|
||||||
prompt.
|
|
||||||
min_new_tokens (`int`, *optional*):
|
min_new_tokens (`int`, *optional*):
|
||||||
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||||
early_stopping (`bool`, *optional*, defaults to `False`):
|
early_stopping (`bool`, *optional*, defaults to `False`):
|
||||||
|
|||||||
@@ -318,20 +318,20 @@ class FlaxGenerationMixin:
|
|||||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
|
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
||||||
f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
|
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
||||||
" config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
|
|
||||||
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
elif has_default_max_length and generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
elif not has_default_max_length and generation_config.max_new_tokens is not None:
|
if not has_default_max_length:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
" limit to the generated output length. Remove one of those arguments. Please refer to the"
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||||
" documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
||||||
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||||
|
|||||||
@@ -700,20 +700,20 @@ class TFGenerationMixin:
|
|||||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
|
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
||||||
f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
|
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
||||||
" config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
|
|
||||||
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
elif has_default_max_length and generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
elif not has_default_max_length and generation_config.max_new_tokens is not None:
|
if not has_default_max_length:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
" limit to the generated output length. Remove one of those arguments. Please refer to the"
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||||
" documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
||||||
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||||
|
|||||||
@@ -1274,20 +1274,20 @@ class GenerationMixin:
|
|||||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
|
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
||||||
f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
|
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
||||||
" config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
|
|
||||||
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
elif has_default_max_length and generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
elif not has_default_max_length and generation_config.max_new_tokens is not None:
|
if not has_default_max_length:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
" limit to the generated output length. Remove one of those arguments. Please refer to the"
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||||
" documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
||||||
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||||
|
|||||||
@@ -2178,10 +2178,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# 1 BOS + 20 + 3 new tokens
|
# 1 BOS + 20 + 3 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 24])
|
self.assertEqual(list(outputs.shape), [1, 24])
|
||||||
|
|
||||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
|
||||||
|
|
||||||
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
|
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||||
@@ -2212,12 +2208,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# 1 BOS + 20 + 3 new tokens
|
# 1 BOS + 20 + 3 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 24])
|
self.assertEqual(list(outputs.shape), [1, 24])
|
||||||
|
|
||||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
t5_model.generate(
|
|
||||||
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
|
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
@@ -2250,12 +2240,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# 1 BOS + 20 + 3 new tokens
|
# 1 BOS + 20 + 3 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 24])
|
self.assertEqual(list(outputs.shape), [1, 24])
|
||||||
|
|
||||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
bart_model.generate(
|
|
||||||
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
|
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
|
||||||
article = """Justin Timberlake."""
|
article = """Justin Timberlake."""
|
||||||
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
|
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
|
||||||
@@ -2279,10 +2263,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# 1 BOS token + 23 new tokens
|
# 1 BOS token + 23 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 24])
|
self.assertEqual(list(outputs.shape), [1, 24])
|
||||||
|
|
||||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
|
|
||||||
|
|
||||||
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
|
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
|
||||||
article = """Justin Timberlake."""
|
article = """Justin Timberlake."""
|
||||||
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
@@ -2306,10 +2286,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# 1 BOS token + 23 new tokens
|
# 1 BOS token + 23 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 24])
|
self.assertEqual(list(outputs.shape), [1, 24])
|
||||||
|
|
||||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
|
|
||||||
|
|
||||||
def test_max_new_tokens_decoder_only(self):
|
def test_max_new_tokens_decoder_only(self):
|
||||||
article = """Justin Timberlake."""
|
article = """Justin Timberlake."""
|
||||||
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
@@ -2333,10 +2309,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
# 1 BOS token + 23 new tokens
|
# 1 BOS token + 23 new tokens
|
||||||
self.assertEqual(list(outputs.shape), [1, 24])
|
self.assertEqual(list(outputs.shape), [1, 24])
|
||||||
|
|
||||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
|
|
||||||
|
|
||||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
|||||||
Reference in New Issue
Block a user