From abf8f54a019ce14b5eaffa68c6dd883be13fe66e Mon Sep 17 00:00:00 2001 From: Daniel Korat Date: Wed, 7 Feb 2024 14:42:01 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9A=A0=EF=B8=8F=20Raise=20`Exception`=20when?= =?UTF-8?q?=20trying=20to=20generate=200=20tokens=20=E2=9A=A0=EF=B8=8F=20(?= =?UTF-8?q?#28621)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * change warning to exception * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante * validate `max_new_tokens` > 0 in `GenerationConfig` * fix truncation test parameterization in `TextGenerationPipelineTests` --------- Co-authored-by: Joao Gante --- src/transformers/generation/configuration_utils.py | 2 ++ src/transformers/generation/utils.py | 5 ++--- tests/pipelines/test_pipelines_text_generation.py | 8 +++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 17b8875a40..25abcc67e9 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -373,6 +373,8 @@ class GenerationConfig(PushToHubMixin): # Validation of individual attributes if self.early_stopping not in {True, False, "never"}: raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") + if self.max_new_tokens is not None and self.max_new_tokens <= 0: + raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.") # Validation of attribute relations: fix_location = "" diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 622d673177..0b8102c353 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1138,11 +1138,10 @@ class GenerationMixin: ) if input_ids_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - warnings.warn( + raise ValueError( f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`.", - UserWarning, + " increasing `max_length` or, better yet, setting `max_new_tokens`." ) # 2. Min length warnings due to unfeasible parameter combinations diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index bf4c1e9f9d..0500e3b035 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -93,17 +93,19 @@ class TextGenerationPipelineTests(unittest.TestCase): ## -- test tokenizer_kwargs test_str = "testing tokenizer kwargs. using truncation must result in a different generation." + input_len = len(text_generator.tokenizer(test_str)["input_ids"]) output_str, output_str_with_truncation = ( - text_generator(test_str, do_sample=False, return_full_text=False)[0]["generated_text"], + text_generator(test_str, do_sample=False, return_full_text=False, min_new_tokens=1)[0]["generated_text"], text_generator( test_str, do_sample=False, return_full_text=False, + min_new_tokens=1, truncation=True, - max_length=3, + max_length=input_len + 1, )[0]["generated_text"], ) - assert output_str != output_str_with_truncation # results must be different because one hd truncation + assert output_str != output_str_with_truncation # results must be different because one had truncation # -- what is the point of this test? padding is hardcoded False in the pipeline anyway text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id