From e396358104a3631be42617168c09fa8894148b0f Mon Sep 17 00:00:00 2001 From: Karim Foda <35491698+KMFODA@users.noreply.github.com> Date: Fri, 30 Sep 2022 16:26:51 +0300 Subject: [PATCH] Add stop sequence to text generation pipeline (#18444) --- src/transformers/generation_utils.py | 1 - .../pipelines/text2text_generation.py | 11 ++++++++++ src/transformers/pipelines/text_generation.py | 11 ++++++++++ tests/generation/test_generation_utils.py | 20 +++++++++++++++++++ .../test_pipelines_text_generation.py | 12 +++++++++++ 5 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 71db5532ea..79460c1cad 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1343,7 +1343,6 @@ class GenerationMixin: stopping_criteria = self._get_stopping_criteria( max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria ) - # 9. go into different generation modes if is_greedy_gen_mode: if num_return_sequences > 1: diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 97cbc1a395..2247e57929 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -1,4 +1,5 @@ import enum +import warnings from ..tokenization_utils import TruncationStrategy from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging @@ -59,6 +60,7 @@ class Text2TextGenerationPipeline(Pipeline): return_type=None, clean_up_tokenization_spaces=None, truncation=None, + stop_sequence=None, **generate_kwargs ): preprocess_params = {} @@ -76,6 +78,15 @@ class Text2TextGenerationPipeline(Pipeline): if clean_up_tokenization_spaces is not None: postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces + if stop_sequence is not None: + stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False) + if len(stop_sequence_ids) > 1: + warnings.warn( + "Stopping on a multiple token sequence is not yet supported on transformers. The first token of" + " the stop sequence will be used as the stop sequence string in the interim." + ) + generate_kwargs["eos_token_id"] = stop_sequence_ids[0] + return preprocess_params, forward_params, postprocess_params def check_inputs(self, input_length: int, min_length: int, max_length: int): diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 7d15316492..4cb78c9bbe 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,4 +1,5 @@ import enum +import warnings from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING @@ -80,6 +81,7 @@ class TextGenerationPipeline(Pipeline): clean_up_tokenization_spaces=None, prefix=None, handle_long_generation=None, + stop_sequence=None, **generate_kwargs ): preprocess_params = {} @@ -121,6 +123,15 @@ class TextGenerationPipeline(Pipeline): if clean_up_tokenization_spaces is not None: postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces + if stop_sequence is not None: + stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False) + if len(stop_sequence_ids) > 1: + warnings.warn( + "Stopping on a multiple token sequence is not yet supported on transformers. The first token of" + " the stop sequence will be used as the stop sequence string in the interim." + ) + generate_kwargs["eos_token_id"] = stop_sequence_ids[0] + return preprocess_params, forward_params, postprocess_params # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index e8cb57ccf3..f48cfff83c 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -37,6 +37,7 @@ if is_torch_available(): Speech2TextForConditionalGeneration, SpeechEncoderDecoderModel, VisionEncoderDecoderModel, + pipeline, top_k_top_p_filtering, ) from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint @@ -1979,6 +1980,25 @@ class GenerationIntegrationTests(unittest.TestCase): [1, 18], ) + def test_stop_sequence_stopping_criteria(self): + + prompt = """Hello I believe in""" + generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") + output = generator(prompt) + self.assertEqual( + output, + [ + { + "generated_text": ( + "Hello I believe in in in number number number number number number number number number" + ) + } + ], + ) + + output = generator(prompt, stop_sequence=" number") + self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}]) + def test_custom_logits_processor(self): bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index a26ed56d4c..ac6d122559 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -147,6 +147,18 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer) return text_generator, ["This is a test", "Another test"] + def test_stop_sequence_stopping_criteria(self): + prompt = """Hello I believe in""" + text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2") + output = text_generator(prompt) + self.assertEqual( + output, + [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}], + ) + + output = text_generator(prompt, stop_sequence=" fe") + self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}]) + def run_pipeline_test(self, text_generator, _): model = text_generator.model tokenizer = text_generator.tokenizer