From 42b60f8b02941b0c40c42e150a101eb372c3856e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 30 Jan 2023 17:53:54 +0000 Subject: [PATCH] Generate: Relaxed `max_length` and `max_new_tokens` coexistence (#21347) Co-authored-by: Patrick von Platen --- .../generation/configuration_utils.py | 6 ++-- src/transformers/generation/flax_utils.py | 22 +++++++-------- src/transformers/generation/tf_utils.py | 22 +++++++-------- src/transformers/generation/utils.py | 22 +++++++-------- tests/generation/test_utils.py | 28 ------------------- 5 files changed, 35 insertions(+), 65 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index a42718c5c7..a869d49ccf 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -63,14 +63,12 @@ class GenerationConfig(PushToHubMixin): max_length (`int`, *optional*, defaults to 20): The maximum length the generated tokens can have. Corresponds to the length of the input prompt + - `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in the - prompt. + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. max_new_tokens (`int`, *optional*): The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. min_length (`int`, *optional*, defaults to 0): The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + - `min_new_tokens`. In general, prefer the use of `min_new_tokens`, which ignores the number of tokens in the - prompt. + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set. min_new_tokens (`int`, *optional*): The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. early_stopping (`bool`, *optional*, defaults to `False`): diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 598f50b64b..a327621c3c 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -318,21 +318,21 @@ class FlaxGenerationMixin: has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" - f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" - " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif has_default_max_length and generation_config.max_new_tokens is not None: + elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - elif not has_default_max_length and generation_config.max_new_tokens is not None: - raise ValueError( - "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" - " limit to the generated output length. Remove one of those arguments. Please refer to the" - " documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index d3d66b599a..deeefdcec4 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -700,21 +700,21 @@ class TFGenerationMixin: has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" - f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" - " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif has_default_max_length and generation_config.max_new_tokens is not None: + elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - elif not has_default_max_length and generation_config.max_new_tokens is not None: - raise ValueError( - "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" - " limit to the generated output length. Remove one of those arguments. Please refer to the" - " documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 39c507b6d2..d18a82d17f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1274,21 +1274,21 @@ class GenerationMixin: has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" - f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" - " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) - elif has_default_max_length and generation_config.max_new_tokens is not None: + elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - elif not has_default_max_length and generation_config.max_new_tokens is not None: - raise ValueError( - "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" - " limit to the generated output length. Remove one of those arguments. Please refer to the" - " documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3339b60091..5a5e578ee4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2178,10 +2178,6 @@ class GenerationIntegrationTests(unittest.TestCase): # 1 BOS + 20 + 3 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and must not be used together. - with self.assertRaises(ValueError): - bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) - def test_max_new_tokens_decoder_only_contrastive_search_t5(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") @@ -2212,12 +2208,6 @@ class GenerationIntegrationTests(unittest.TestCase): # 1 BOS + 20 + 3 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and must not be used together. - with self.assertRaises(ValueError): - t5_model.generate( - decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4 - ) - def test_max_new_tokens_decoder_only_contrastive_search_bart(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") @@ -2250,12 +2240,6 @@ class GenerationIntegrationTests(unittest.TestCase): # 1 BOS + 20 + 3 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and must not be used together. - with self.assertRaises(ValueError): - bart_model.generate( - decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4 - ) - def test_max_new_tokens_decoder_only_contrastive_search_gptj(self): article = """Justin Timberlake.""" gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj") @@ -2279,10 +2263,6 @@ class GenerationIntegrationTests(unittest.TestCase): # 1 BOS token + 23 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and must not be used together. - with self.assertRaises(ValueError): - gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) - def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self): article = """Justin Timberlake.""" gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") @@ -2306,10 +2286,6 @@ class GenerationIntegrationTests(unittest.TestCase): # 1 BOS token + 23 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and must not be used together. - with self.assertRaises(ValueError): - gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4) - def test_max_new_tokens_decoder_only(self): article = """Justin Timberlake.""" gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") @@ -2333,10 +2309,6 @@ class GenerationIntegrationTests(unittest.TestCase): # 1 BOS token + 23 new tokens self.assertEqual(list(outputs.shape), [1, 24]) - # max_new_tokens and max_length serve the same purpose and must not be used together. - with self.assertRaises(ValueError): - gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20) - def test_encoder_decoder_generate_with_inputs_embeds(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")