Output dicts support in text generation pipeline (#35092)
* Support for generate_argument: return_dict_in_generate=True, instead of returning an error * fix: call test with return_dict_in_generate=True * fix: Only import torch if it is present * update: Encapsulate output_dict changes * fix: added back original comments --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -3,11 +3,13 @@ import itertools
|
|||||||
import types
|
import types
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
|
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
|
||||||
from .base import Pipeline, build_pipeline_init_args
|
from .base import Pipeline, build_pipeline_init_args
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
from .pt_utils import KeyDataset
|
from .pt_utils import KeyDataset
|
||||||
|
|
||||||
@@ -380,13 +382,44 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
if "generation_config" not in generate_kwargs:
|
if "generation_config" not in generate_kwargs:
|
||||||
generate_kwargs["generation_config"] = self.generation_config
|
generate_kwargs["generation_config"] = self.generation_config
|
||||||
|
|
||||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||||
|
|
||||||
|
if isinstance(output, ModelOutput):
|
||||||
|
generated_sequence = output.sequences
|
||||||
|
other_outputs = {k: v for k, v in output.items() if k != "sequences"}
|
||||||
|
out_b = generated_sequence.shape[0]
|
||||||
|
|
||||||
|
if self.framework == "pt":
|
||||||
|
for key, value in other_outputs.items():
|
||||||
|
if isinstance(value, torch.Tensor) and value.shape[0] == out_b:
|
||||||
|
other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
|
||||||
|
if isinstance(value, tuple) and len(value[0]) == out_b:
|
||||||
|
value = torch.stack(value).swapaxes(0, 1)
|
||||||
|
other_outputs[key] = value
|
||||||
|
elif self.framework == "tf":
|
||||||
|
for key, value in other_outputs.items():
|
||||||
|
if isinstance(value, tf.Tensor) and value.shape[0] == out_b:
|
||||||
|
other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:]))
|
||||||
|
if isinstance(value, tuple) and len(value[0]) == out_b:
|
||||||
|
value = tf.stack(value).swapaxes(0, 1)
|
||||||
|
other_outputs[key] = value
|
||||||
|
else:
|
||||||
|
generated_sequence = output
|
||||||
|
other_outputs = {}
|
||||||
|
|
||||||
out_b = generated_sequence.shape[0]
|
out_b = generated_sequence.shape[0]
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||||
elif self.framework == "tf":
|
elif self.framework == "tf":
|
||||||
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
|
||||||
|
model_outputs = {
|
||||||
|
"generated_sequence": generated_sequence,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"prompt_text": prompt_text,
|
||||||
|
}
|
||||||
|
model_outputs.update(other_outputs)
|
||||||
|
return model_outputs
|
||||||
|
|
||||||
def postprocess(
|
def postprocess(
|
||||||
self,
|
self,
|
||||||
@@ -400,7 +433,19 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
prompt_text = model_outputs["prompt_text"]
|
prompt_text = model_outputs["prompt_text"]
|
||||||
generated_sequence = generated_sequence.numpy().tolist()
|
generated_sequence = generated_sequence.numpy().tolist()
|
||||||
records = []
|
records = []
|
||||||
for sequence in generated_sequence:
|
other_outputs = model_outputs.get("additional_outputs", {})
|
||||||
|
splitted_keys = {}
|
||||||
|
if other_outputs:
|
||||||
|
if self.framework == "pt":
|
||||||
|
for k, v in other_outputs.items():
|
||||||
|
if isinstance(v, torch.Tensor) and v.shape[0] == len(generated_sequence):
|
||||||
|
splitted_keys[k] = v.numpy().tolist()
|
||||||
|
elif self.framework == "tf":
|
||||||
|
for k, v in other_outputs.items():
|
||||||
|
if isinstance(v, tf.Tensor) and v.shape[0] == len(generated_sequence):
|
||||||
|
splitted_keys[k] = v.numpy().tolist()
|
||||||
|
|
||||||
|
for idx, sequence in enumerate(generated_sequence):
|
||||||
if return_type == ReturnType.TENSORS:
|
if return_type == ReturnType.TENSORS:
|
||||||
record = {"generated_token_ids": sequence}
|
record = {"generated_token_ids": sequence}
|
||||||
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
||||||
@@ -444,6 +489,8 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
# When we're not starting from a prefill, the output is a new assistant message
|
# When we're not starting from a prefill, the output is a new assistant message
|
||||||
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||||
record = {"generated_text": all_text}
|
record = {"generated_text": all_text}
|
||||||
|
for key, values in splitted_keys.items():
|
||||||
|
record[key] = values[idx]
|
||||||
records.append(record)
|
records.append(record)
|
||||||
|
|
||||||
return records
|
return records
|
||||||
|
|||||||
@@ -653,6 +653,31 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
_ = text_generator(prompt, max_length=10)
|
_ = text_generator(prompt, max_length=10)
|
||||||
self.assertNotIn(logger_msg, cl.out)
|
self.assertNotIn(logger_msg, cl.out)
|
||||||
|
|
||||||
|
def test_return_dict_in_generate(self):
|
||||||
|
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2", max_new_tokens=16)
|
||||||
|
out = text_generator(
|
||||||
|
["This is great !", "Something else"], return_dict_in_generate=True, output_logits=True, output_scores=True
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
out,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": ANY(str),
|
||||||
|
"logits": ANY(list),
|
||||||
|
"scores": ANY(list),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": ANY(str),
|
||||||
|
"logits": ANY(list),
|
||||||
|
"scores": ANY(list),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pipeline_assisted_generation(self):
|
def test_pipeline_assisted_generation(self):
|
||||||
"""Tests that we can run assisted generation in the pipeline"""
|
"""Tests that we can run assisted generation in the pipeline"""
|
||||||
|
|||||||
Reference in New Issue
Block a user