From 52a02137557963e9dd58c9be65b6cef871d3bf32 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 2 Sep 2024 13:23:47 +0100 Subject: [PATCH] Add assistant prefill for chat templates and TextGenerationPipeline (#33198) * Add assistant prefill to chat templates * Add assistant prefill to pipeline * Add assistant prefill to pipeline * Tweak another test that ended in assistant message * Update tests that ended in assistant messages * Update tests that ended in assistant messages * Replace assistant_prefill with continue_final_message * Allow passing continue_final_message to pipeline * Small fixup * Add continue_final_message as a pipeline kwarg * Update docstrings * Move repos to hf-internal-testing! * Update src/transformers/tokenization_utils_base.py Co-authored-by: Lysandre Debut * Add explanatory comment * make fixup * Update chat templating docs to explain continue_last_message --------- Co-authored-by: Lysandre Debut --- docs/source/en/chat_templating.md | 37 +++++++++ src/transformers/pipelines/text_generation.py | 44 +++++++++-- src/transformers/tokenization_utils_base.py | 22 +++++- tests/pipelines/test_pipelines_common.py | 12 +-- .../test_pipelines_text_generation.py | 77 ++++++++++++++++--- tests/test_tokenization_common.py | 30 ++++++++ 6 files changed, 199 insertions(+), 23 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index ac12b3c640..10b094e08f 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -197,6 +197,43 @@ Not all models require generation prompts. Some models, like BlenderBot and LLaM special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact effect that `add_generation_prompt` has will depend on the template being used. +## What does "continue_last_message" do? + +When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose +to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done +by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply +extend the final message when it begins to generate text. This is useful for "prefilling" the model's response. + +Here's an example: + +```python +chat = [ + {"role": "user", "content": "Can you format the answer in JSON?"}, + {"role": "assistant", "content": '{"name": "'}, +] + +formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_last_message=True) +model.generate(**formatted_chat) +``` + +The model will generate text that continues the JSON string, rather than starting a new message. This approach +can be very useful for improving the accuracy of the model's instruction-following when you know how you want +it to start its replies. + +Because `add_generation_prompt` adds the tokens that start a new message, and `continue_last_message` removes any +end-of-message tokens from the final message, it does not make sense to use them together. As a result, you'll +get an error if you try! + + + +The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new +message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is +a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple +consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_last_message` +argument when calling the pipeline. + + + ## Can I use chat templates in training? Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training. diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 025a81f3a4..8bd1017ffc 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -131,6 +131,7 @@ class TextGenerationPipeline(Pipeline): stop_sequence=None, truncation=None, max_length=None, + continue_final_message=None, **generate_kwargs, ): preprocess_params = {} @@ -165,6 +166,9 @@ class TextGenerationPipeline(Pipeline): ) preprocess_params["handle_long_generation"] = handle_long_generation + if continue_final_message is not None: + preprocess_params["continue_final_message"] = continue_final_message + preprocess_params.update(generate_kwargs) forward_params = generate_kwargs @@ -183,6 +187,8 @@ class TextGenerationPipeline(Pipeline): postprocess_params["return_type"] = return_type if clean_up_tokenization_spaces is not None: postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces + if continue_final_message is not None: + postprocess_params["continue_final_message"] = continue_final_message if stop_sequence is not None: stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False) @@ -226,6 +232,10 @@ class TextGenerationPipeline(Pipeline): *return_text* is set to True. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): Whether or not to clean up the potential extra spaces in the text output. + continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the + last message in the input chat rather than starting a new one, allowing you to "prefill" its response. + By default this is `True` when the final message in the input chat has the `assistant` role and + `False` otherwise, but you can manually override that behaviour by setting this flag. prefix (`str`, *optional*): Prefix added to prompt. handle_long_generation (`str`, *optional*): @@ -270,6 +280,7 @@ class TextGenerationPipeline(Pipeline): truncation=None, padding=None, max_length=None, + continue_final_message=None, **generate_kwargs, ): # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults @@ -283,9 +294,14 @@ class TextGenerationPipeline(Pipeline): if isinstance(prompt_text, Chat): tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats + # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default + # because very few models support multiple separate, consecutive assistant messages + if continue_final_message is None: + continue_final_message = prompt_text.messages[-1]["role"] == "assistant" inputs = self.tokenizer.apply_chat_template( prompt_text.messages, - add_generation_prompt=True, + add_generation_prompt=not continue_final_message, + continue_final_message=continue_final_message, return_dict=True, return_tensors=self.framework, **tokenizer_kwargs, @@ -356,7 +372,13 @@ class TextGenerationPipeline(Pipeline): generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text} - def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True): + def postprocess( + self, + model_outputs, + return_type=ReturnType.FULL_TEXT, + clean_up_tokenization_spaces=True, + continue_final_message=None, + ): generated_sequence = model_outputs["generated_sequence"][0] input_ids = model_outputs["input_ids"] prompt_text = model_outputs["prompt_text"] @@ -390,9 +412,21 @@ class TextGenerationPipeline(Pipeline): if isinstance(prompt_text, str): all_text = prompt_text + all_text elif isinstance(prompt_text, Chat): - # Explicit list parsing is necessary for parsing chat datasets - all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}] - + if continue_final_message is None: + # If the user passes a chat ending in an assistant message, we treat it as a prefill by + # default because very few models support multiple separate, consecutive assistant messages + continue_final_message = prompt_text.messages[-1]["role"] == "assistant" + if continue_final_message: + # With assistant prefill, concat onto the end of the last message + all_text = list(prompt_text.messages)[:-1] + [ + { + "role": prompt_text.messages[-1]["role"], + "content": prompt_text.messages[-1]["content"] + all_text, + } + ] + else: + # When we're not starting from a prefill, the output is a new assistant message + all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}] record = {"generated_text": all_text} records.append(record) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 3f3f0a3b37..608c651666 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1704,6 +1704,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): documents: Optional[List[Dict[str, str]]] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = False, + continue_final_message: bool = False, tokenize: bool = True, padding: bool = False, truncation: bool = False, @@ -1737,10 +1738,16 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): chat_template (`str`, *optional*): A Jinja template to use for this conversion. It is usually not necessary to pass anything to this argument, as the model's template will be used by default. - add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate - the start of an assistant message. This is useful when you want to generate a response from the model. + add_generation_prompt (bool, *optional*): + If this is set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the template for this argument to have any effect. + continue_final_message (bool, *optional*): + If this is set, the chat will be formatted so that the final + message in the chat is open-ended, without any EOS tokens. The model will continue this message + rather than starting a new one. This allows you to "prefill" part of + the model's response for it. Cannot be used at the same time as `add_generation_prompt`. tokenize (`bool`, defaults to `True`): Whether to tokenize the output. If `False`, the output will be a string. padding (`bool`, defaults to `False`): @@ -1803,6 +1810,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): conversations = [conversation] is_batched = False + if continue_final_message: + if add_generation_prompt: + raise ValueError( + "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead." + ) + if return_assistant_tokens_mask: + raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.") + # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas if tools is not None: tool_schemas = [] @@ -1849,6 +1864,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): add_generation_prompt=add_generation_prompt, **template_kwargs, ) + if continue_final_message: + final_message = chat[-1]["content"] + rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip() rendered.append(rendered_chat) if not is_batched: diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 779cb0ac0b..ea36ae5728 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -877,8 +877,8 @@ class CustomPipelineTest(unittest.TestCase): # See https://github.com/huggingface/transformers/issues/31669 text_generator = pipeline( "text-generation", - model="Rocketknight1/fake-custom-model-test", - tokenizer="Rocketknight1/fake-custom-model-test", + model="hf-internal-testing/tiny-random-custom-architecture", + tokenizer="hf-internal-testing/tiny-random-custom-architecture", trust_remote_code=True, ) @@ -888,8 +888,8 @@ class CustomPipelineTest(unittest.TestCase): def test_custom_code_with_string_feature_extractor(self): speech_recognizer = pipeline( "automatic-speech-recognition", - model="Rocketknight1/fake-custom-wav2vec2", - feature_extractor="Rocketknight1/fake-custom-wav2vec2", + model="hf-internal-testing/fake-custom-wav2vec2", + feature_extractor="hf-internal-testing/fake-custom-wav2vec2", trust_remote_code=True, ) @@ -899,8 +899,8 @@ class CustomPipelineTest(unittest.TestCase): def test_custom_code_with_string_preprocessor(self): mask_generator = pipeline( "mask-generation", - model="Rocketknight1/fake-custom-sam", - processor="Rocketknight1/fake-custom-sam", + model="hf-internal-testing/fake-custom-sam", + processor="hf-internal-testing/fake-custom-sam", trust_remote_code=True, ) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 94132b5f55..930eb34bfb 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -148,18 +148,16 @@ class TextGenerationPipelineTests(unittest.TestCase): @require_torch def test_small_chat_model_pt(self): text_generator = pipeline( - task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" + task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt" ) # Using `do_sample=False` to force deterministic output chat1 = [ {"role": "system", "content": "This is a system message."}, {"role": "user", "content": "This is a test"}, - {"role": "assistant", "content": "This is a reply"}, ] chat2 = [ {"role": "system", "content": "This is a system message."}, {"role": "user", "content": "This is a second test"}, - {"role": "assistant", "content": "This is a reply"}, ] outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) expected_chat1 = chat1 + [ @@ -179,7 +177,7 @@ class TextGenerationPipelineTests(unittest.TestCase): expected_chat2 = chat2 + [ { "role": "assistant", - "content": " factors factors factors factors factors factors factors factors factors factors", + "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs", } ] @@ -191,6 +189,68 @@ class TextGenerationPipelineTests(unittest.TestCase): ], ) + @require_torch + def test_small_chat_model_continue_final_message(self): + # Here we check that passing a chat that ends in an assistant message is handled correctly + # by continuing the final message rather than starting a new one + text_generator = pipeline( + task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt" + ) + # Using `do_sample=False` to force deterministic output + chat1 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + {"role": "assistant", "content": "This is"}, + ] + outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) + + # Assert that we continued the last message and there isn't a sneaky <|im_end|> + self.assertEqual( + outputs, + [ + { + "generated_text": [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + { + "role": "assistant", + "content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs", + }, + ] + } + ], + ) + + @require_torch + def test_small_chat_model_continue_final_message_override(self): + # Here we check that passing a chat that ends in an assistant message is handled correctly + # by continuing the final message rather than starting a new one + text_generator = pipeline( + task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt" + ) + # Using `do_sample=False` to force deterministic output + chat1 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + ] + outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True) + + # Assert that we continued the last message and there isn't a sneaky <|im_end|> + self.assertEqual( + outputs, + [ + { + "generated_text": [ + {"role": "system", "content": "This is a system message."}, + { + "role": "user", + "content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs", + }, + ] + } + ], + ) + @require_torch def test_small_chat_model_with_dataset_pt(self): from torch.utils.data import Dataset @@ -202,7 +262,6 @@ class TextGenerationPipelineTests(unittest.TestCase): [ {"role": "system", "content": "This is a system message."}, {"role": "user", "content": "This is a test"}, - {"role": "assistant", "content": "This is a reply"}, ], ] @@ -213,7 +272,7 @@ class TextGenerationPipelineTests(unittest.TestCase): return {"text": self.data[i]} text_generator = pipeline( - task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" + task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt" ) dataset = MyDataset() @@ -277,18 +336,16 @@ class TextGenerationPipelineTests(unittest.TestCase): @require_tf def test_small_chat_model_tf(self): text_generator = pipeline( - task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf" + task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf" ) # Using `do_sample=False` to force deterministic output chat1 = [ {"role": "system", "content": "This is a system message."}, {"role": "user", "content": "This is a test"}, - {"role": "assistant", "content": "This is a reply"}, ] chat2 = [ {"role": "system", "content": "This is a system message."}, {"role": "user", "content": "This is a second test"}, - {"role": "assistant", "content": "This is a reply"}, ] outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) expected_chat1 = chat1 + [ @@ -308,7 +365,7 @@ class TextGenerationPipelineTests(unittest.TestCase): expected_chat2 = chat2 + [ { "role": "assistant", - "content": " factors factors factors factors factors factors factors factors factors factors", + "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs", } ] diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index f1bcfe3929..64c860e3fc 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1327,6 +1327,36 @@ class TokenizerTesterMixin: [0] * (assistant_start2 - assistant_end - 1), ) + @require_jinja + def test_continue_final_message(self): + dummy_template = """ + {%- for message in messages %} + {{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}} + {%- endfor %}""" + dummy_conversation = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message"}, + ] + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False + ) + self.assertEqual( + output, + "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n", + ) + prefill_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True + ) + # Assert that the final message is unterminated + self.assertEqual( + prefill_output, + "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message", + ) + @require_jinja def test_chat_template_dict(self): dummy_template_1 = "{{'a'}}"