Generate: text generation pipeline no longer emits max_length warning when it is not set (#23139)

This commit is contained in:
Joao Gante
2023-05-04 18:36:23 +01:00
committed by GitHub
parent 516dc6305f
commit b369e507aa
5 changed files with 56 additions and 14 deletions

View File

@@ -1,3 +1,4 @@
import copy
import enum
import warnings
@@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
prefix_inputs = self.tokenizer(
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
)
prefix_length = prefix_inputs["input_ids"].shape[-1]
generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
if "max_new_tokens" in generate_kwargs:
pass
elif "max_length" in generate_kwargs:
generate_kwargs["max_length"] += prefix_length
else:
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
if "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
if handle_long_generation is not None:
if handle_long_generation not in {"hole"}:
raise ValueError(
@@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
# If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
# generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
generate_kwargs = copy.deepcopy(generate_kwargs)
prefix_length = generate_kwargs.pop("prefix_length", 0)
if prefix_length > 0:
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].max_new_tokens is not None
)
if not has_max_new_tokens:
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
generate_kwargs["max_length"] += prefix_length
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
"generation_config" in generate_kwargs
and generate_kwargs["generation_config"].min_new_tokens is not None
)
if not has_min_new_tokens and "min_length" in generate_kwargs:
generate_kwargs["min_length"] += prefix_length
# BS x SL
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
out_b = generated_sequence.shape[0]