Add assistant prefill for chat templates and TextGenerationPipeline (#33198)

* Add assistant prefill to chat templates

* Add assistant prefill to pipeline

* Add assistant prefill to pipeline

* Tweak another test that ended in assistant message

* Update tests that ended in assistant messages

* Update tests that ended in assistant messages

* Replace assistant_prefill with continue_final_message

* Allow passing continue_final_message to pipeline

* Small fixup

* Add continue_final_message as a pipeline kwarg

* Update docstrings

* Move repos to hf-internal-testing!

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Add explanatory comment

* make fixup

* Update chat templating docs to explain continue_last_message

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Matt
2024-09-02 13:23:47 +01:00
committed by GitHub
parent 2d37085817
commit 52a0213755
6 changed files with 199 additions and 23 deletions

View File

@@ -197,6 +197,43 @@ Not all models require generation prompts. Some models, like BlenderBot and LLaM
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
effect that `add_generation_prompt` has will depend on the template being used. effect that `add_generation_prompt` has will depend on the template being used.
## What does "continue_last_message" do?
When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
Here's an example:
```python
chat = [
{"role": "user", "content": "Can you format the answer in JSON?"},
{"role": "assistant", "content": '{"name": "'},
]
formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_last_message=True)
model.generate(**formatted_chat)
```
The model will generate text that continues the JSON string, rather than starting a new message. This approach
can be very useful for improving the accuracy of the model's instruction-following when you know how you want
it to start its replies.
Because `add_generation_prompt` adds the tokens that start a new message, and `continue_last_message` removes any
end-of-message tokens from the final message, it does not make sense to use them together. As a result, you'll
get an error if you try!
<Tip>
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_last_message`
argument when calling the pipeline.
</Tip>
## Can I use chat templates in training? ## Can I use chat templates in training?
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training. Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.

View File

@@ -131,6 +131,7 @@ class TextGenerationPipeline(Pipeline):
stop_sequence=None, stop_sequence=None,
truncation=None, truncation=None,
max_length=None, max_length=None,
continue_final_message=None,
**generate_kwargs, **generate_kwargs,
): ):
preprocess_params = {} preprocess_params = {}
@@ -165,6 +166,9 @@ class TextGenerationPipeline(Pipeline):
) )
preprocess_params["handle_long_generation"] = handle_long_generation preprocess_params["handle_long_generation"] = handle_long_generation
if continue_final_message is not None:
preprocess_params["continue_final_message"] = continue_final_message
preprocess_params.update(generate_kwargs) preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs forward_params = generate_kwargs
@@ -183,6 +187,8 @@ class TextGenerationPipeline(Pipeline):
postprocess_params["return_type"] = return_type postprocess_params["return_type"] = return_type
if clean_up_tokenization_spaces is not None: if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
if continue_final_message is not None:
postprocess_params["continue_final_message"] = continue_final_message
if stop_sequence is not None: if stop_sequence is not None:
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False) stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
@@ -226,6 +232,10 @@ class TextGenerationPipeline(Pipeline):
*return_text* is set to True. *return_text* is set to True.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the potential extra spaces in the text output. Whether or not to clean up the potential extra spaces in the text output.
continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
By default this is `True` when the final message in the input chat has the `assistant` role and
`False` otherwise, but you can manually override that behaviour by setting this flag.
prefix (`str`, *optional*): prefix (`str`, *optional*):
Prefix added to prompt. Prefix added to prompt.
handle_long_generation (`str`, *optional*): handle_long_generation (`str`, *optional*):
@@ -270,6 +280,7 @@ class TextGenerationPipeline(Pipeline):
truncation=None, truncation=None,
padding=None, padding=None,
max_length=None, max_length=None,
continue_final_message=None,
**generate_kwargs, **generate_kwargs,
): ):
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
@@ -283,9 +294,14 @@ class TextGenerationPipeline(Pipeline):
if isinstance(prompt_text, Chat): if isinstance(prompt_text, Chat):
tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
inputs = self.tokenizer.apply_chat_template( inputs = self.tokenizer.apply_chat_template(
prompt_text.messages, prompt_text.messages,
add_generation_prompt=True, add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_dict=True, return_dict=True,
return_tensors=self.framework, return_tensors=self.framework,
**tokenizer_kwargs, **tokenizer_kwargs,
@@ -356,7 +372,13 @@ class TextGenerationPipeline(Pipeline):
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text} return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True): def postprocess(
self,
model_outputs,
return_type=ReturnType.FULL_TEXT,
clean_up_tokenization_spaces=True,
continue_final_message=None,
):
generated_sequence = model_outputs["generated_sequence"][0] generated_sequence = model_outputs["generated_sequence"][0]
input_ids = model_outputs["input_ids"] input_ids = model_outputs["input_ids"]
prompt_text = model_outputs["prompt_text"] prompt_text = model_outputs["prompt_text"]
@@ -390,9 +412,21 @@ class TextGenerationPipeline(Pipeline):
if isinstance(prompt_text, str): if isinstance(prompt_text, str):
all_text = prompt_text + all_text all_text = prompt_text + all_text
elif isinstance(prompt_text, Chat): elif isinstance(prompt_text, Chat):
# Explicit list parsing is necessary for parsing chat datasets if continue_final_message is None:
# If the user passes a chat ending in an assistant message, we treat it as a prefill by
# default because very few models support multiple separate, consecutive assistant messages
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
if continue_final_message:
# With assistant prefill, concat onto the end of the last message
all_text = list(prompt_text.messages)[:-1] + [
{
"role": prompt_text.messages[-1]["role"],
"content": prompt_text.messages[-1]["content"] + all_text,
}
]
else:
# When we're not starting from a prefill, the output is a new assistant message
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}] all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
record = {"generated_text": all_text} record = {"generated_text": all_text}
records.append(record) records.append(record)

View File

@@ -1704,6 +1704,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
documents: Optional[List[Dict[str, str]]] = None, documents: Optional[List[Dict[str, str]]] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
add_generation_prompt: bool = False, add_generation_prompt: bool = False,
continue_final_message: bool = False,
tokenize: bool = True, tokenize: bool = True,
padding: bool = False, padding: bool = False,
truncation: bool = False, truncation: bool = False,
@@ -1737,10 +1738,16 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
chat_template (`str`, *optional*): chat_template (`str`, *optional*):
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
argument, as the model's template will be used by default. argument, as the model's template will be used by default.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate add_generation_prompt (bool, *optional*):
the start of an assistant message. This is useful when you want to generate a response from the model. If this is set, a prompt with the token(s) that indicate
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the Note that this argument will be passed to the chat template, and so it must be supported in the
template for this argument to have any effect. template for this argument to have any effect.
continue_final_message (bool, *optional*):
If this is set, the chat will be formatted so that the final
message in the chat is open-ended, without any EOS tokens. The model will continue this message
rather than starting a new one. This allows you to "prefill" part of
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
tokenize (`bool`, defaults to `True`): tokenize (`bool`, defaults to `True`):
Whether to tokenize the output. If `False`, the output will be a string. Whether to tokenize the output. If `False`, the output will be a string.
padding (`bool`, defaults to `False`): padding (`bool`, defaults to `False`):
@@ -1803,6 +1810,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
conversations = [conversation] conversations = [conversation]
is_batched = False is_batched = False
if continue_final_message:
if add_generation_prompt:
raise ValueError(
"continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
)
if return_assistant_tokens_mask:
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
if tools is not None: if tools is not None:
tool_schemas = [] tool_schemas = []
@@ -1849,6 +1864,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
**template_kwargs, **template_kwargs,
) )
if continue_final_message:
final_message = chat[-1]["content"]
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
rendered.append(rendered_chat) rendered.append(rendered_chat)
if not is_batched: if not is_batched:

View File

@@ -877,8 +877,8 @@ class CustomPipelineTest(unittest.TestCase):
# See https://github.com/huggingface/transformers/issues/31669 # See https://github.com/huggingface/transformers/issues/31669
text_generator = pipeline( text_generator = pipeline(
"text-generation", "text-generation",
model="Rocketknight1/fake-custom-model-test", model="hf-internal-testing/tiny-random-custom-architecture",
tokenizer="Rocketknight1/fake-custom-model-test", tokenizer="hf-internal-testing/tiny-random-custom-architecture",
trust_remote_code=True, trust_remote_code=True,
) )
@@ -888,8 +888,8 @@ class CustomPipelineTest(unittest.TestCase):
def test_custom_code_with_string_feature_extractor(self): def test_custom_code_with_string_feature_extractor(self):
speech_recognizer = pipeline( speech_recognizer = pipeline(
"automatic-speech-recognition", "automatic-speech-recognition",
model="Rocketknight1/fake-custom-wav2vec2", model="hf-internal-testing/fake-custom-wav2vec2",
feature_extractor="Rocketknight1/fake-custom-wav2vec2", feature_extractor="hf-internal-testing/fake-custom-wav2vec2",
trust_remote_code=True, trust_remote_code=True,
) )
@@ -899,8 +899,8 @@ class CustomPipelineTest(unittest.TestCase):
def test_custom_code_with_string_preprocessor(self): def test_custom_code_with_string_preprocessor(self):
mask_generator = pipeline( mask_generator = pipeline(
"mask-generation", "mask-generation",
model="Rocketknight1/fake-custom-sam", model="hf-internal-testing/fake-custom-sam",
processor="Rocketknight1/fake-custom-sam", processor="hf-internal-testing/fake-custom-sam",
trust_remote_code=True, trust_remote_code=True,
) )

View File

@@ -148,18 +148,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
@require_torch @require_torch
def test_small_chat_model_pt(self): def test_small_chat_model_pt(self):
text_generator = pipeline( text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
) )
# Using `do_sample=False` to force deterministic output # Using `do_sample=False` to force deterministic output
chat1 = [ chat1 = [
{"role": "system", "content": "This is a system message."}, {"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"}, {"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
] ]
chat2 = [ chat2 = [
{"role": "system", "content": "This is a system message."}, {"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"}, {"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
] ]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [ expected_chat1 = chat1 + [
@@ -179,7 +177,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
expected_chat2 = chat2 + [ expected_chat2 = chat2 + [
{ {
"role": "assistant", "role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors", "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
} }
] ]
@@ -191,6 +189,68 @@ class TextGenerationPipelineTests(unittest.TestCase):
], ],
) )
@require_torch
def test_small_chat_model_continue_final_message(self):
# Here we check that passing a chat that ends in an assistant message is handled correctly
# by continuing the final message rather than starting a new one
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
self.assertEqual(
outputs,
[
{
"generated_text": [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{
"role": "assistant",
"content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
},
]
}
],
)
@require_torch
def test_small_chat_model_continue_final_message_override(self):
# Here we check that passing a chat that ends in an assistant message is handled correctly
# by continuing the final message rather than starting a new one
text_generator = pipeline(
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
self.assertEqual(
outputs,
[
{
"generated_text": [
{"role": "system", "content": "This is a system message."},
{
"role": "user",
"content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
},
]
}
],
)
@require_torch @require_torch
def test_small_chat_model_with_dataset_pt(self): def test_small_chat_model_with_dataset_pt(self):
from torch.utils.data import Dataset from torch.utils.data import Dataset
@@ -202,7 +262,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
[ [
{"role": "system", "content": "This is a system message."}, {"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"}, {"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
], ],
] ]
@@ -213,7 +272,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
return {"text": self.data[i]} return {"text": self.data[i]}
text_generator = pipeline( text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
) )
dataset = MyDataset() dataset = MyDataset()
@@ -277,18 +336,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
@require_tf @require_tf
def test_small_chat_model_tf(self): def test_small_chat_model_tf(self):
text_generator = pipeline( text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf" task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf"
) )
# Using `do_sample=False` to force deterministic output # Using `do_sample=False` to force deterministic output
chat1 = [ chat1 = [
{"role": "system", "content": "This is a system message."}, {"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"}, {"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
] ]
chat2 = [ chat2 = [
{"role": "system", "content": "This is a system message."}, {"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"}, {"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
] ]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [ expected_chat1 = chat1 + [
@@ -308,7 +365,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
expected_chat2 = chat2 + [ expected_chat2 = chat2 + [
{ {
"role": "assistant", "role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors", "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
} }
] ]

View File

@@ -1327,6 +1327,36 @@ class TokenizerTesterMixin:
[0] * (assistant_start2 - assistant_end - 1), [0] * (assistant_start2 - assistant_end - 1),
) )
@require_jinja
def test_continue_final_message(self):
dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
{%- endfor %}"""
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
)
self.assertEqual(
output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
)
prefill_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
)
# Assert that the final message is unterminated
self.assertEqual(
prefill_output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
)
@require_jinja @require_jinja
def test_chat_template_dict(self): def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}" dummy_template_1 = "{{'a'}}"