From 8b46c5bcfc56d3e5ae6409a0af64e2bf57229484 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 4 Oct 2023 15:15:29 +0100 Subject: [PATCH] Add add_generation_prompt argument to apply_chat_template (#26573) * Add add_generation_prompt argument to apply_chat_template * Add add_generation_prompt argument to apply_chat_template and update default templates * Fix typo * Add generation prompts section to chat templating guide * Add generation prompts section to chat templating guide * Minor style fix --- docs/source/en/chat_templating.md | 55 ++++++++++++++++++- .../tokenization_gpt_neox_japanese.py | 5 +- src/transformers/pipelines/conversational.py | 2 +- src/transformers/tokenization_utils_base.py | 12 +++- 4 files changed, 69 insertions(+), 5 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index f568c9949e..af4ec4c06b 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -218,10 +218,11 @@ input formats. Our default template for models that don't have a class-specific {% endfor %} ``` -If you like this one, here it is in one-liner form, ready to copy into your code: +If you like this one, here it is in one-liner form, ready to copy into your code. The one-liner also includes +handy support for "generation prompts" - see the next section for more! 
``` -tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}" +tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" ``` This template wraps each message in `<|im_start|>` and `<|im_end|>` tokens, and simply writes the role as a string, which @@ -240,6 +241,56 @@ The "user", "system" and "assistant" roles are the standard for chat, and we rec particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited to these roles - templating is extremely flexible, and any string can be a role. +## What are "generation prompts"? + +You may notice that the `apply_chat_template` method has an `add_generation_prompt` argument. This argument tells +the template to add tokens that indicate the start of a bot response. 
For example, consider the following chat: + +```python +messages = [ + {"role": "user", "content": "Hi there!"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "Can I ask a question?"} +] +``` + +Here's what this will look like without a generation prompt, using the ChatML template we described above: + +```python +>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) +"""<|im_start|>user +Hi there!<|im_end|> +<|im_start|>assistant +Nice to meet you!<|im_end|> +<|im_start|>user +Can I ask a question?<|im_end|> +""" +``` + +And here's what it looks like **with** a generation prompt: + +```python +>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +"""<|im_start|>user +Hi there!<|im_end|> +<|im_start|>assistant +Nice to meet you!<|im_end|> +<|im_start|>user +Can I ask a question?<|im_end|> +<|im_start|>assistant +""" +``` + +Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model +generates text it will write a bot response instead of doing something unexpected, like continuing the user's +message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a +special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're +supposed to be doing. + +Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any +special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact +effect that `add_generation_prompt` has will depend on the template being used. + ## I want to use chat templates! How should I get started? 
If you have any chat models, you should set their `tokenizer.chat_template` attribute and test it using diff --git a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py index 7fca57d4c1..c035087948 100644 --- a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py @@ -181,7 +181,10 @@ class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer): A simple chat template that just adds BOS/EOS tokens around messages while discarding role information. """ return ( - "{% for message in messages %}" "{{ bos_token + eos_token + message.content + eos_token }}" "{% endfor %}" + "{% for message in messages %}" + "{{ bos_token + eos_token + message.content + eos_token }}" + "{% endfor %}" + "{% if add_generation_prompt %} {{ bos_token + eos_token }} {% endif %}" ) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 3d037799c8..639ad868f2 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -262,7 +262,7 @@ class ConversationalPipeline(Pipeline): return outputs def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]: - input_ids = self.tokenizer.apply_chat_template(conversation) + input_ids = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True) if self.framework == "pt": input_ids = torch.LongTensor([input_ids]) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index c2285ad479..cf30c7695f 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1718,6 +1718,7 @@ class 
PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): self, conversation: Union[List[Dict[str, str]], "Conversation"], chat_template: Optional[str] = None, + add_generation_prompt: bool = False, tokenize: bool = True, padding: bool = False, truncation: bool = False, @@ -1736,6 +1737,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): with "role" and "content" keys, representing the chat history so far. chat_template (str, *optional*): A Jinja template to use for this conversion. If this is not passed, the model's default chat template will be used instead. + add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate + the start of an assistant message. This is useful when you want to generate a response from the model. + Note that this argument will be passed to the chat template, and so it must be supported in the + template for this argument to have any effect. tokenize (`bool`, defaults to `True`): Whether to tokenize the output. If `False`, the output will be a string. padding (`bool`, defaults to `False`): @@ -1773,7 +1778,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # Compilation function uses a cache to avoid recompiling the same template compiled_template = self._compile_jinja_template(chat_template) - rendered = compiled_template.render(messages=conversation, **self.special_tokens_map) + rendered = compiled_template.render( + messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map + ) if padding is True: padding = "max_length" # There's only one sequence here, so "longest" makes no sense @@ -1815,6 +1822,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): "{% for message in messages %}" "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" ) @classmethod