⚠️ Raise Exception when trying to generate 0 tokens ⚠️ (#28621)
* change warning to exception * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * validate `max_new_tokens` > 0 in `GenerationConfig` * fix truncation test parameterization in `TextGenerationPipelineTests` --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -373,6 +373,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# Validation of individual attributes
|
# Validation of individual attributes
|
||||||
if self.early_stopping not in {True, False, "never"}:
|
if self.early_stopping not in {True, False, "never"}:
|
||||||
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
||||||
|
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
|
||||||
|
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
||||||
|
|
||||||
# Validation of attribute relations:
|
# Validation of attribute relations:
|
||||||
fix_location = ""
|
fix_location = ""
|
||||||
|
|||||||
@@ -1138,11 +1138,10 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
if input_ids_length >= generation_config.max_length:
|
if input_ids_length >= generation_config.max_length:
|
||||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||||
warnings.warn(
|
raise ValueError(
|
||||||
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
||||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||||
" increasing `max_new_tokens`.",
|
" increasing `max_length` or, better yet, setting `max_new_tokens`."
|
||||||
UserWarning,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Min length warnings due to unfeasible parameter combinations
|
# 2. Min length warnings due to unfeasible parameter combinations
|
||||||
|
|||||||
@@ -93,17 +93,19 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
## -- test tokenizer_kwargs
|
## -- test tokenizer_kwargs
|
||||||
test_str = "testing tokenizer kwargs. using truncation must result in a different generation."
|
test_str = "testing tokenizer kwargs. using truncation must result in a different generation."
|
||||||
|
input_len = len(text_generator.tokenizer(test_str)["input_ids"])
|
||||||
output_str, output_str_with_truncation = (
|
output_str, output_str_with_truncation = (
|
||||||
text_generator(test_str, do_sample=False, return_full_text=False)[0]["generated_text"],
|
text_generator(test_str, do_sample=False, return_full_text=False, min_new_tokens=1)[0]["generated_text"],
|
||||||
text_generator(
|
text_generator(
|
||||||
test_str,
|
test_str,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
return_full_text=False,
|
return_full_text=False,
|
||||||
|
min_new_tokens=1,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=3,
|
max_length=input_len + 1,
|
||||||
)[0]["generated_text"],
|
)[0]["generated_text"],
|
||||||
)
|
)
|
||||||
assert output_str != output_str_with_truncation # results must be different because one hd truncation
|
assert output_str != output_str_with_truncation # results must be different because one had truncation
|
||||||
|
|
||||||
# -- what is the point of this test? padding is hardcoded False in the pipeline anyway
|
# -- what is the point of this test? padding is hardcoded False in the pipeline anyway
|
||||||
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
|
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
|
||||||
|
|||||||
Reference in New Issue
Block a user