Fixing support batch_size and num_return_Sequences in text-generation pipeline (#15318)
* Fixing support `batch_size` and `num_return_Sequences` in `text-generation` pipeline And `text2text-generation` too. The bug was caused by the batch_size containing both the incoming batch **and** the generated `num_sequences`. The fix simply consists into splitting both of these again into different dimensions. * TF support. * Odd backward compatibility script in the way.
This commit is contained in:
@@ -136,7 +136,11 @@ class Text2TextGenerationPipeline(Pipeline):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
result = super().__call__(*args, **kwargs)
|
result = super().__call__(*args, **kwargs)
|
||||||
if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]):
|
if (
|
||||||
|
isinstance(args[0], list)
|
||||||
|
and all(isinstance(el, str) for el in args[0])
|
||||||
|
and all(len(res) == 1 for res in result)
|
||||||
|
):
|
||||||
return [res[0] for res in result]
|
return [res[0] for res in result]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -146,19 +150,24 @@ class Text2TextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
def _forward(self, model_inputs, **generate_kwargs):
|
def _forward(self, model_inputs, **generate_kwargs):
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
input_length = model_inputs["input_ids"].shape[-1]
|
in_b, input_length = model_inputs["input_ids"].shape
|
||||||
elif self.framework == "tf":
|
elif self.framework == "tf":
|
||||||
input_length = tf.shape(model_inputs["input_ids"])[-1].numpy()
|
in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
|
||||||
|
|
||||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
|
generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
|
||||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
|
generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||||
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
|
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
|
||||||
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
|
||||||
|
out_b = output_ids.shape[0]
|
||||||
|
if self.framework == "pt":
|
||||||
|
output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
|
||||||
|
elif self.framework == "tf":
|
||||||
|
output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
|
||||||
return {"output_ids": output_ids}
|
return {"output_ids": output_ids}
|
||||||
|
|
||||||
def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
|
def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
|
||||||
records = []
|
records = []
|
||||||
for output_ids in model_outputs["output_ids"]:
|
for output_ids in model_outputs["output_ids"][0]:
|
||||||
if return_type == ReturnType.TENSORS:
|
if return_type == ReturnType.TENSORS:
|
||||||
record = {f"{self.return_name}_token_ids": model_outputs}
|
record = {f"{self.return_name}_token_ids": model_outputs}
|
||||||
elif return_type == ReturnType.TEXT:
|
elif return_type == ReturnType.TEXT:
|
||||||
|
|||||||
@@ -2,10 +2,14 @@ import enum
|
|||||||
|
|
||||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
|
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings
|
from ..file_utils import add_end_docstrings, is_tf_available
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
class ReturnType(enum.Enum):
|
class ReturnType(enum.Enum):
|
||||||
TENSORS = 0
|
TENSORS = 0
|
||||||
NEW_TEXT = 1
|
NEW_TEXT = 1
|
||||||
@@ -202,23 +206,29 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
# Allow empty prompts
|
# Allow empty prompts
|
||||||
if input_ids.shape[1] == 0:
|
if input_ids.shape[1] == 0:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
in_b = 1
|
||||||
|
else:
|
||||||
|
in_b = input_ids.shape[0]
|
||||||
prompt_text = model_inputs.pop("prompt_text")
|
prompt_text = model_inputs.pop("prompt_text")
|
||||||
generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
|
generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
|
||||||
|
out_b = generated_sequence.shape[0]
|
||||||
|
if self.framework == "pt":
|
||||||
|
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||||
|
elif self.framework == "tf":
|
||||||
|
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
||||||
|
|
||||||
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
||||||
generated_sequence = model_outputs["generated_sequence"]
|
generated_sequence = model_outputs["generated_sequence"][0]
|
||||||
input_ids = model_outputs["input_ids"]
|
input_ids = model_outputs["input_ids"]
|
||||||
prompt_text = model_outputs["prompt_text"]
|
prompt_text = model_outputs["prompt_text"]
|
||||||
if self.framework == "pt" and generated_sequence is not None:
|
|
||||||
generated_sequence = generated_sequence.cpu()
|
|
||||||
generated_sequence = generated_sequence.numpy().tolist()
|
generated_sequence = generated_sequence.numpy().tolist()
|
||||||
if return_type == ReturnType.TENSORS:
|
records = []
|
||||||
record = {"generated_token_ids": generated_sequence}
|
for sequence in generated_sequence:
|
||||||
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
if return_type == ReturnType.TENSORS:
|
||||||
# Decode text
|
record = {"generated_token_ids": generated_sequence}
|
||||||
record = []
|
elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
||||||
for sequence in generated_sequence:
|
# Decode text
|
||||||
text = self.tokenizer.decode(
|
text = self.tokenizer.decode(
|
||||||
sequence,
|
sequence,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
@@ -242,7 +252,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
all_text = text[prompt_length:]
|
all_text = text[prompt_length:]
|
||||||
|
|
||||||
item = {"generated_text": all_text}
|
record = {"generated_text": all_text}
|
||||||
record.append(item)
|
records.append(record)
|
||||||
|
|
||||||
return record
|
return records
|
||||||
|
|||||||
@@ -40,6 +40,26 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
# These are encoder decoder, they don't just append to incoming string
|
# These are encoder decoder, they don't just append to incoming string
|
||||||
self.assertFalse(outputs[0]["generated_text"].startswith("Something there"))
|
self.assertFalse(outputs[0]["generated_text"].startswith("Something there"))
|
||||||
|
|
||||||
|
outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = generator(
|
||||||
|
["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
generator(4)
|
generator(4)
|
||||||
|
|
||||||
|
|||||||
@@ -113,6 +113,27 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
|||||||
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
|
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
|
||||||
self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))
|
self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))
|
||||||
|
|
||||||
|
outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if text_generator.tokenizer.pad_token is not None:
|
||||||
|
outputs = text_generator(
|
||||||
|
["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
[{"generated_text": ANY(str)}, {"generated_text": ANY(str)}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Empty prompt is slighly special
|
# Empty prompt is slighly special
|
||||||
# it requires BOS token to exist.
|
# it requires BOS token to exist.
|
||||||
# Special case for Pegasus which will always append EOS so will
|
# Special case for Pegasus which will always append EOS so will
|
||||||
|
|||||||
Reference in New Issue
Block a user