Add add_generation_prompt argument to apply_chat_template (#26573)
* Add add_generation_prompt argument to apply_chat_template
* Add add_generation_prompt argument to apply_chat_template and update default templates
* Fix typo
* Add generation prompts section to chat templating guide
* Add generation prompts section to chat templating guide
* Minor style fix
@@ -218,10 +218,11 @@ input formats. Our default template for models that don't have a class-specific
 {% endfor %}
 ```
 
-If you like this one, here it is in one-liner form, ready to copy into your code:
+If you like this one, here it is in one-liner form, ready to copy into your code. The one-liner also includes
+handy support for "generation prompts" - see the next section for more!
 
 ```
-tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
+tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 ```
 
 This template wraps each message in `<|im_start|>` and `<|im_end|>` tokens, and simply writes the role as a string, which
@@ -240,6 +241,56 @@ The "user", "system" and "assistant" roles are the standard for chat, and we rec
 particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
 to these roles - templating is extremely flexible, and any string can be a role.
 
+## What are "generation prompts"?
+
+You may notice that the `apply_chat_template` method has an `add_generation_prompt` argument. This argument tells
+the template to add tokens that indicate the start of a bot response. For example, consider the following chat:
+
+```python
+messages = [
+    {"role": "user", "content": "Hi there!"},
+    {"role": "assistant", "content": "Nice to meet you!"},
+    {"role": "user", "content": "Can I ask a question?"}
+]
+```
+
+Here's what this will look like without a generation prompt, using the ChatML template we described above:
+
+```python
+>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+"""<|im_start|>user
+Hi there!<|im_end|>
+<|im_start|>assistant
+Nice to meet you!<|im_end|>
+<|im_start|>user
+Can I ask a question?<|im_end|>
+"""
+```
+
+And here's what it looks like **with** a generation prompt:
+
+```python
+>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+"""<|im_start|>user
+Hi there!<|im_end|>
+<|im_start|>assistant
+Nice to meet you!<|im_end|>
+<|im_start|>user
+Can I ask a question?<|im_end|>
+<|im_start|>assistant
+"""
+```
+
+Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
+generates text it will write a bot response instead of doing something unexpected, like continuing the user's
+message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
+special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're
+supposed to be doing.
+
+Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any
+special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
+effect that `add_generation_prompt` has will depend on the template being used.
+
 ## I want to use chat templates! How should I get started?
 
 If you have any chat models, you should set their `tokenizer.chat_template` attribute and test it using
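The behavior documented in the hunk above can be sketched in plain Python, without loading a tokenizer. This is only an illustration under stated assumptions: `render_chatml` and `render_plain` are hypothetical helpers written for this note, not part of `transformers`, and they mimic by hand what the Jinja templates would render. Real code should call `tokenizer.apply_chat_template`.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Hand-written stand-in for the ChatML template (hypothetical helper)."""
    # Wrap every message as <|im_start|>role\ncontent<|im_end|>\n
    rendered = "".join(
        "<|im_start|>" + m["role"] + "\n" + m["content"] + "<|im_end|>" + "\n"
        for m in messages
    )
    if add_generation_prompt:
        # Same suffix the template emits inside {% if add_generation_prompt %}
        rendered += "<|im_start|>assistant\n"
    return rendered


def render_plain(messages, add_generation_prompt=False):
    """Hypothetical template with no control tokens before a bot response."""
    # The flag is accepted but never consulted, so it cannot change the output.
    return "\n".join(m["content"] for m in messages)


messages = [
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Nice to meet you!"},
    {"role": "user", "content": "Can I ask a question?"},
]

without_prompt = render_chatml(messages)
with_prompt = render_chatml(messages, add_generation_prompt=True)

# The two renderings differ only by the trailing assistant header.
print(with_prompt[len(without_prompt):] == "<|im_start|>assistant\n")  # prints True
```

The second helper shows why the argument can be a no-op: a template that never references `add_generation_prompt` renders the same text either way.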
@@ -181,7 +181,10 @@ class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer):
         A simple chat template that just adds BOS/EOS tokens around messages while discarding role information.
         """
         return (
-            "{% for message in messages %}" "{{ bos_token + eos_token + message.content + eos_token }}" "{% endfor %}"
+            "{% for message in messages %}"
+            "{{ bos_token + eos_token + message.content + eos_token }}"
+            "{% endfor %}"
+            "{% if add_generation_prompt %} {{ bos_token + eos_token }} {% endif %}"
         )
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
@@ -262,7 +262,7 @@ class ConversationalPipeline(Pipeline):
         return outputs
 
     def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
-        input_ids = self.tokenizer.apply_chat_template(conversation)
+        input_ids = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
 
         if self.framework == "pt":
             input_ids = torch.LongTensor([input_ids])
@@ -1718,6 +1718,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self,
         conversation: Union[List[Dict[str, str]], "Conversation"],
         chat_template: Optional[str] = None,
+        add_generation_prompt: bool = False,
         tokenize: bool = True,
         padding: bool = False,
         truncation: bool = False,
@@ -1736,6 +1737,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 with "role" and "content" keys, representing the chat history so far.
             chat_template (str, *optional*): A Jinja template to use for this conversion. If
                 this is not passed, the model's default chat template will be used instead.
+            add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
+                the start of an assistant message. This is useful when you want to generate a response from the model.
+                Note that this argument will be passed to the chat template, and so it must be supported in the
+                template for this argument to have any effect.
             tokenize (`bool`, defaults to `True`):
                 Whether to tokenize the output. If `False`, the output will be a string.
             padding (`bool`, defaults to `False`):
@@ -1773,7 +1778,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # Compilation function uses a cache to avoid recompiling the same template
         compiled_template = self._compile_jinja_template(chat_template)
 
-        rendered = compiled_template.render(messages=conversation, **self.special_tokens_map)
+        rendered = compiled_template.render(
+            messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
+        )
 
         if padding is True:
             padding = "max_length"  # There's only one sequence here, so "longest" makes no sense
@@ -1815,6 +1822,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             "{% for message in messages %}"
             "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
             "{% endfor %}"
+            "{% if add_generation_prompt %}"
+            "{{ '<|im_start|>assistant\n' }}"
+            "{% endif %}"
         )
 
     @classmethod