From dc540dd316819dc77d8c9d719c89c74df42b4d05 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 29 Oct 2021 15:29:28 +0200 Subject: [PATCH] Adding `handle_long_generation` paramters for `text-generation` pipeline. (#14118) * Adding `handle_long_generation` paramters for `text-generation` pipeline. * More error handling * Fixing tests by dropping tf support on this functionality, it needs `max_new_tokens` to make it possible to understand user's intent. Otherwise, `max_length` == `tokenizer.model_max_length` < input_ids.shape[0]. * Fixing doc ? * Doc ? * Remove link from doc. * Catched an issue on roberta. * Damn doc. * Non BC proposal ? * Cleaning the fix ? * Finally using only a test override. * Don't need to modify this. * Bad print. --- .../models/reformer/modeling_reformer.py | 2 +- src/transformers/pipelines/text_generation.py | 45 ++++++++++++++++++- tests/test_pipelines_common.py | 4 +- tests/test_pipelines_text_generation.py | 21 +++++++++ 4 files changed, 68 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index c7ee43a566..528875b4aa 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -254,7 +254,7 @@ class ReformerEmbeddings(nn.Module): if position_ids.shape[-1] > self.max_position_embeddings: raise ValueError( - f"Sequence Length: {position_ids.shape[-1]} has to be larger equal than " + f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than " f"config.max_position_embeddings {self.max_position_embeddings}." ) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 03d9621b4c..0179e82b08 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -75,6 +75,7 @@ class TextGenerationPipeline(Pipeline): return_type=None, clean_up_tokenization_spaces=None, prefix=None, + handle_long_generation=None, **generate_kwargs ): preprocess_params = {} @@ -85,14 +86,24 @@ class TextGenerationPipeline(Pipeline): prefix, padding=False, add_special_tokens=False, return_tensors=self.framework ) prefix_length = prefix_inputs["input_ids"].shape[-1] - if "max_length" in generate_kwargs: + + if "max_new_tokens" in generate_kwargs: + pass + elif "max_length" in generate_kwargs: generate_kwargs["max_length"] += prefix_length else: generate_kwargs["max_length"] = self.model.config.max_length + prefix_length if "min_length" in generate_kwargs: generate_kwargs["min_length"] += prefix_length + if handle_long_generation is not None: + if handle_long_generation not in {"hole"}: + raise ValueError( + f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']" + ) + preprocess_params["handle_long_generation"] = handle_long_generation + preprocess_params.update(generate_kwargs) forward_params = generate_kwargs postprocess_params = {} @@ -136,6 +147,16 @@ class TextGenerationPipeline(Pipeline): Whether or not to clean up the potential extra spaces in the text output. prefix (:obj:`str`, `optional`): Prefix added to prompt. + handle_long_generation (:obj:`str`, `optional`): + By default, this pipelines does not handle long generation (ones that exceed in one form or the other + the model maximum length). There is no perfect way to adress this (more info + :https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common + strategies to work around that problem depending on your use case. + + - :obj:`None` : default strategy where nothing in particular happens + - :obj:`"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might + truncate a lot of the prompt and not suitable when generation exceed the model capacity) + generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method corresponding to your framework `here <./model.html#generative-models>`__). @@ -149,11 +170,31 @@ class TextGenerationPipeline(Pipeline): """ return super().__call__(text_inputs, **kwargs) - def preprocess(self, prompt_text, prefix=""): + def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs): inputs = self.tokenizer( prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework ) inputs["prompt_text"] = prompt_text + + if handle_long_generation == "hole": + cur_len = inputs["input_ids"].shape[-1] + if "max_new_tokens" in generate_kwargs: + new_tokens = generate_kwargs["max_new_tokens"] + else: + new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len + if new_tokens < 0: + raise ValueError("We cannot infer how many new tokens are expected") + if cur_len + new_tokens > self.tokenizer.model_max_length: + keep_length = self.tokenizer.model_max_length - new_tokens + if keep_length <= 0: + raise ValueError( + "We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length" + ) + + inputs["input_ids"] = inputs["input_ids"][:, -keep_length:] + if "attention_mask" in inputs: + inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:] + return inputs def _forward(self, model_inputs, **generate_kwargs): diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 5cb9fb6ecb..ba4397b5e9 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type): try: tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint) # XLNet actually defines it as -1. - if ( + if model.config.__class__.__name__ == "RobertaConfig": + tokenizer.model_max_length = model.config.max_position_embeddings - 2 + elif ( hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings > 0 ): diff --git a/tests/test_pipelines_text_generation.py b/tests/test_pipelines_text_generation.py index ebe71a5591..2990d2c55b 100644 --- a/tests/test_pipelines_text_generation.py +++ b/tests/test_pipelines_text_generation.py @@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM else: with self.assertRaises((ValueError, AssertionError)): outputs = text_generator("") + + if text_generator.framework == "tf": + # TF generation does not support max_new_tokens, and it's impossible + # to control long generation with only max_length without + # fancy calculation, dismissing tests for now. + return + # We don't care about infinite range models. + # They already work. + if tokenizer.model_max_length < 10000: + # Handling of large generations + with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)): + text_generator("This is a test" * 500, max_new_tokens=20) + + outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20) + # Hole strategy cannot work + with self.assertRaises(ValueError): + text_generator( + "This is a test" * 500, + handle_long_generation="hole", + max_new_tokens=tokenizer.model_max_length + 10, + )