Generate: text generation pipeline no longer emits max_length warning when it is not set (#23139)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import enum
|
||||
import warnings
|
||||
|
||||
@@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
|
||||
prefix_inputs = self.tokenizer(
|
||||
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
|
||||
)
|
||||
prefix_length = prefix_inputs["input_ids"].shape[-1]
|
||||
generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
|
||||
|
||||
if "max_new_tokens" in generate_kwargs:
|
||||
pass
|
||||
elif "max_length" in generate_kwargs:
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
else:
|
||||
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
|
||||
|
||||
if "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
if handle_long_generation is not None:
|
||||
if handle_long_generation not in {"hole"}:
|
||||
raise ValueError(
|
||||
@@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
prompt_text = model_inputs.pop("prompt_text")
|
||||
|
||||
# If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
||||
# generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
||||
generate_kwargs = copy.deepcopy(generate_kwargs)
|
||||
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
||||
if prefix_length > 0:
|
||||
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].max_new_tokens is not None
|
||||
)
|
||||
if not has_max_new_tokens:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].min_new_tokens is not None
|
||||
)
|
||||
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
|
||||
# BS x SL
|
||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user