Generate: text generation pipeline no longer emits max_length warning when it is not set (#23139)
This commit is contained in:
@@ -385,7 +385,6 @@ class FlaxGenerationMixin:
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
elif generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
|
||||||
if not has_default_max_length:
|
if not has_default_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
@@ -393,6 +392,7 @@ class FlaxGenerationMixin:
|
|||||||
"Please refer to the documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||||
)
|
)
|
||||||
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
|
|
||||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -858,7 +858,6 @@ class TFGenerationMixin:
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
elif generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
|
||||||
if not has_default_max_length:
|
if not has_default_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
@@ -866,6 +865,7 @@ class TFGenerationMixin:
|
|||||||
"Please refer to the documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||||
)
|
)
|
||||||
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
|
|
||||||
# If the input length is a tensor (i.e. dynamic length), skip length checks
|
# If the input length is a tensor (i.e. dynamic length), skip length checks
|
||||||
if not isinstance(input_ids_seq_length, tf.Tensor):
|
if not isinstance(input_ids_seq_length, tf.Tensor):
|
||||||
|
|||||||
@@ -1348,7 +1348,6 @@ class GenerationMixin:
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
elif generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
|
||||||
if not has_default_max_length:
|
if not has_default_max_length:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
@@ -1356,6 +1355,7 @@ class GenerationMixin:
|
|||||||
"Please refer to the documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||||
)
|
)
|
||||||
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
|
|
||||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import enum
|
import enum
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
prefix_inputs = self.tokenizer(
|
prefix_inputs = self.tokenizer(
|
||||||
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
|
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
|
||||||
)
|
)
|
||||||
prefix_length = prefix_inputs["input_ids"].shape[-1]
|
generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
if "max_new_tokens" in generate_kwargs:
|
|
||||||
pass
|
|
||||||
elif "max_length" in generate_kwargs:
|
|
||||||
generate_kwargs["max_length"] += prefix_length
|
|
||||||
else:
|
|
||||||
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
|
|
||||||
|
|
||||||
if "min_length" in generate_kwargs:
|
|
||||||
generate_kwargs["min_length"] += prefix_length
|
|
||||||
if handle_long_generation is not None:
|
if handle_long_generation is not None:
|
||||||
if handle_long_generation not in {"hole"}:
|
if handle_long_generation not in {"hole"}:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
in_b = input_ids.shape[0]
|
in_b = input_ids.shape[0]
|
||||||
prompt_text = model_inputs.pop("prompt_text")
|
prompt_text = model_inputs.pop("prompt_text")
|
||||||
|
|
||||||
|
# If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
||||||
|
# generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
||||||
|
generate_kwargs = copy.deepcopy(generate_kwargs)
|
||||||
|
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
||||||
|
if prefix_length > 0:
|
||||||
|
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
||||||
|
"generation_config" in generate_kwargs
|
||||||
|
and generate_kwargs["generation_config"].max_new_tokens is not None
|
||||||
|
)
|
||||||
|
if not has_max_new_tokens:
|
||||||
|
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
||||||
|
generate_kwargs["max_length"] += prefix_length
|
||||||
|
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||||
|
"generation_config" in generate_kwargs
|
||||||
|
and generate_kwargs["generation_config"].min_new_tokens is not None
|
||||||
|
)
|
||||||
|
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||||
|
generate_kwargs["min_length"] += prefix_length
|
||||||
|
|
||||||
# BS x SL
|
# BS x SL
|
||||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||||
out_b = generated_sequence.shape[0]
|
out_b = generated_sequence.shape[0]
|
||||||
|
|||||||
@@ -14,8 +14,15 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
from transformers import (
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TextGenerationPipeline,
|
||||||
|
logging,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_tf,
|
require_tf,
|
||||||
@@ -323,3 +330,26 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
||||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||||
|
|
||||||
|
def test_pipeline_length_setting_warning(self):
|
||||||
|
prompt = """Hello world"""
|
||||||
|
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
|
||||||
|
if text_generator.model.framework == "tf":
|
||||||
|
logger = logging.get_logger("transformers.generation.tf_utils")
|
||||||
|
else:
|
||||||
|
logger = logging.get_logger("transformers.generation.utils")
|
||||||
|
logger_msg = "Both `max_new_tokens`" # The beginning of the message to be checked in this test
|
||||||
|
|
||||||
|
# Both are set by the user -> log warning
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
_ = text_generator(prompt, max_length=10, max_new_tokens=1)
|
||||||
|
self.assertIn(logger_msg, cl.out)
|
||||||
|
|
||||||
|
# The user only sets one -> no warning
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
_ = text_generator(prompt, max_new_tokens=1)
|
||||||
|
self.assertNotIn(logger_msg, cl.out)
|
||||||
|
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
_ = text_generator(prompt, max_length=10)
|
||||||
|
self.assertNotIn(logger_msg, cl.out)
|
||||||
|
|||||||
Reference in New Issue
Block a user