From 06107541d3df39e3d9685e64c3b942d9aed06d75 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 28 Jan 2022 12:15:30 +0100 Subject: [PATCH] Fixing support `batch_size` and `num_return_Sequences` in `text-generation` pipeline (#15318) * Fixing support `batch_size` and `num_return_Sequences` in `text-generation` pipeline And `text2text-generation` too. The bug was caused by the batch_size containing both the incoming batch **and** the generated `num_sequences`. The fix simply consists into splitting both of these again into different dimensions. * TF support. * Odd backward compatibility script in the way. --- .../pipelines/text2text_generation.py | 17 ++++++--- src/transformers/pipelines/text_generation.py | 36 ++++++++++++------- tests/test_pipelines_text2text_generation.py | 20 +++++++++++ tests/test_pipelines_text_generation.py | 21 +++++++++++ 4 files changed, 77 insertions(+), 17 deletions(-) diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 924104d5bb..6b194a5cf9 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -136,7 +136,11 @@ class Text2TextGenerationPipeline(Pipeline): """ result = super().__call__(*args, **kwargs) - if isinstance(args[0], list) and all(isinstance(el, str) for el in args[0]): + if ( + isinstance(args[0], list) + and all(isinstance(el, str) for el in args[0]) + and all(len(res) == 1 for res in result) + ): return [res[0] for res in result] return result @@ -146,19 +150,24 @@ class Text2TextGenerationPipeline(Pipeline): def _forward(self, model_inputs, **generate_kwargs): if self.framework == "pt": - input_length = model_inputs["input_ids"].shape[-1] + in_b, input_length = model_inputs["input_ids"].shape elif self.framework == "tf": - input_length = tf.shape(model_inputs["input_ids"])[-1].numpy() + in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy() generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length) generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) output_ids = self.model.generate(**model_inputs, **generate_kwargs) + out_b = output_ids.shape[0] + if self.framework == "pt": + output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:]) + elif self.framework == "tf": + output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:])) return {"output_ids": output_ids} def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False): records = [] - for output_ids in model_outputs["output_ids"]: + for output_ids in model_outputs["output_ids"][0]: if return_type == ReturnType.TENSORS: record = {f"{self.return_name}_token_ids": model_outputs} elif return_type == ReturnType.TEXT: diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 885c1f8da7..0c1d0093b7 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -2,10 +2,14 @@ import enum from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING -from ..file_utils import add_end_docstrings +from ..file_utils import add_end_docstrings, is_tf_available from .base import PIPELINE_INIT_ARGS, Pipeline +if is_tf_available(): + import tensorflow as tf + + class ReturnType(enum.Enum): TENSORS = 0 NEW_TEXT = 1 @@ -202,23 +206,29 @@ class TextGenerationPipeline(Pipeline): # Allow empty prompts if input_ids.shape[1] == 0: input_ids = None + in_b = 1 + else: + in_b = input_ids.shape[0] prompt_text = model_inputs.pop("prompt_text") generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL + out_b = generated_sequence.shape[0] + if self.framework == "pt": + generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) + elif self.framework == "tf": + generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text} def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True): - generated_sequence = model_outputs["generated_sequence"] + generated_sequence = model_outputs["generated_sequence"][0] input_ids = model_outputs["input_ids"] prompt_text = model_outputs["prompt_text"] - if self.framework == "pt" and generated_sequence is not None: - generated_sequence = generated_sequence.cpu() generated_sequence = generated_sequence.numpy().tolist() - if return_type == ReturnType.TENSORS: - record = {"generated_token_ids": generated_sequence} - elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}: - # Decode text - record = [] - for sequence in generated_sequence: + records = [] + for sequence in generated_sequence: + if return_type == ReturnType.TENSORS: + record = {"generated_token_ids": generated_sequence} + elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}: + # Decode text text = self.tokenizer.decode( sequence, skip_special_tokens=True, @@ -242,7 +252,7 @@ class TextGenerationPipeline(Pipeline): else: all_text = text[prompt_length:] - item = {"generated_text": all_text} - record.append(item) + record = {"generated_text": all_text} + records.append(record) - return record + return records diff --git a/tests/test_pipelines_text2text_generation.py b/tests/test_pipelines_text2text_generation.py index 99ea547105..563b41954b 100644 --- a/tests/test_pipelines_text2text_generation.py +++ b/tests/test_pipelines_text2text_generation.py @@ -40,6 +40,26 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest # These are encoder decoder, they don't just append to incoming string self.assertFalse(outputs[0]["generated_text"].startswith("Something there")) + outputs = generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True) + self.assertEqual( + outputs, + [ + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + ], + ) + + outputs = generator( + ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True + ) + self.assertEqual( + outputs, + [ + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + ], + ) + with self.assertRaises(ValueError): generator(4) diff --git a/tests/test_pipelines_text_generation.py b/tests/test_pipelines_text_generation.py index 2990d2c55b..2dfdc92698 100644 --- a/tests/test_pipelines_text_generation.py +++ b/tests/test_pipelines_text_generation.py @@ -113,6 +113,27 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM self.assertEqual(outputs, [{"generated_text": ANY(str)}]) self.assertTrue(outputs[0]["generated_text"].startswith("This is a test")) + outputs = text_generator(["This is great !", "Something else"], num_return_sequences=2, do_sample=True) + self.assertEqual( + outputs, + [ + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + ], + ) + + if text_generator.tokenizer.pad_token is not None: + outputs = text_generator( + ["This is great !", "Something else"], num_return_sequences=2, batch_size=2, do_sample=True + ) + self.assertEqual( + outputs, + [ + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + [{"generated_text": ANY(str)}, {"generated_text": ANY(str)}], + ], + ) + # Empty prompt is slighly special # it requires BOS token to exist. # Special case for Pegasus which will always append EOS so will