From 3f9cb335047315edfd4b6ad666ef148e98cc4850 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 16 Aug 2023 15:30:54 +0100 Subject: [PATCH] Generate: fix default max length warning (#25539) --- src/transformers/generation/flax_utils.py | 2 +- src/transformers/generation/tf_utils.py | 2 +- src/transformers/generation/utils.py | 2 +- .../models/musicgen/modeling_musicgen.py | 2 +- tests/generation/test_utils.py | 26 +++++++++++++++++++ 5 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 228bfb4a2e..284e0f51cd 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -377,7 +377,7 @@ class FlaxGenerationMixin: # Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20: + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: # 20 is the default max_length of the generation config warnings.warn( f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 648ec710cf..df392cf5ca 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -829,7 +829,7 @@ class TFGenerationMixin: # 7. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = shape_list(input_ids)[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20: + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: # 20 is the default max_length of the generation config warnings.warn( f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0a32785ef6..404943da05 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1249,7 +1249,7 @@ class GenerationMixin: """Performs validation related to the resulting generated length""" # 1. Max length warnings related to poor parameterization - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20: + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: # 20 is the default max_length of the generation config warnings.warn( f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the" diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a656839936..314ea051e9 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1300,7 +1300,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): # 5. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20: + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: logger.warning( f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation.", diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e6faf5babd..ea9ab3c753 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -16,6 +16,7 @@ import inspect import unittest +import warnings import numpy as np @@ -2844,3 +2845,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi with self.assertRaises(TypeError): # FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo" bart_model.generate(input_ids, foo="bar") + + def test_default_max_length_warning(self): + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Default generation config value of 20 -> emits warning + with self.assertWarns(UserWarning): + model.generate(input_ids) + + # Explicitly setting max_length to 20 -> no warning + with warnings.catch_warnings(record=True) as warning_list: + model.generate(input_ids, max_length=20) + self.assertEqual(len(warning_list), 0) + + # Generation config max_length != 20 -> no warning + with warnings.catch_warnings(record=True) as warning_list: + model.generation_config.max_length = 10 + model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence + model.generate(input_ids) + self.assertEqual(len(warning_list), 0)