Add assistant prefill for chat templates and TextGenerationPipeline (#33198)
* Add assistant prefill to chat templates * Add assistant prefill to pipeline * Add assistant prefill to pipeline * Tweak another test that ended in assistant message * Update tests that ended in assistant messages * Update tests that ended in assistant messages * Replace assistant_prefill with continue_final_message * Allow passing continue_final_message to pipeline * Small fixup * Add continue_final_message as a pipeline kwarg * Update docstrings * Move repos to hf-internal-testing! * Update src/transformers/tokenization_utils_base.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Add explanatory comment * make fixup * Update chat templating docs to explain continue_last_message --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -197,6 +197,43 @@ Not all models require generation prompts. Some models, like BlenderBot and LLaM
|
||||
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
|
||||
effect that `add_generation_prompt` has will depend on the template being used.
|
||||
|
||||
## What does "continue_last_message" do?
|
||||
|
||||
When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
|
||||
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
|
||||
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
|
||||
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
|
||||
|
||||
Here's an example:
|
||||
|
||||
```python
|
||||
chat = [
|
||||
{"role": "user", "content": "Can you format the answer in JSON?"},
|
||||
{"role": "assistant", "content": '{"name": "'},
|
||||
]
|
||||
|
||||
formatted_chat = tokenizer.apply_chat_template(chat, tokenize=True, return_dict=True, continue_last_message=True)
|
||||
model.generate(**formatted_chat)
|
||||
```
|
||||
|
||||
The model will generate text that continues the JSON string, rather than starting a new message. This approach
|
||||
can be very useful for improving the accuracy of the model's instruction-following when you know how you want
|
||||
it to start its replies.
|
||||
|
||||
Because `add_generation_prompt` adds the tokens that start a new message, and `continue_last_message` removes any
|
||||
end-of-message tokens from the final message, it does not make sense to use them together. As a result, you'll
|
||||
get an error if you try!
|
||||
|
||||
<Tip>
|
||||
|
||||
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
|
||||
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
|
||||
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
|
||||
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_last_message`
|
||||
argument when calling the pipeline.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Can I use chat templates in training?
|
||||
|
||||
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
|
||||
|
||||
@@ -131,6 +131,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
stop_sequence=None,
|
||||
truncation=None,
|
||||
max_length=None,
|
||||
continue_final_message=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
preprocess_params = {}
|
||||
@@ -165,6 +166,9 @@ class TextGenerationPipeline(Pipeline):
|
||||
)
|
||||
preprocess_params["handle_long_generation"] = handle_long_generation
|
||||
|
||||
if continue_final_message is not None:
|
||||
preprocess_params["continue_final_message"] = continue_final_message
|
||||
|
||||
preprocess_params.update(generate_kwargs)
|
||||
forward_params = generate_kwargs
|
||||
|
||||
@@ -183,6 +187,8 @@ class TextGenerationPipeline(Pipeline):
|
||||
postprocess_params["return_type"] = return_type
|
||||
if clean_up_tokenization_spaces is not None:
|
||||
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
|
||||
if continue_final_message is not None:
|
||||
postprocess_params["continue_final_message"] = continue_final_message
|
||||
|
||||
if stop_sequence is not None:
|
||||
stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
|
||||
@@ -226,6 +232,10 @@ class TextGenerationPipeline(Pipeline):
|
||||
*return_text* is set to True.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the
|
||||
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
|
||||
By default this is `True` when the final message in the input chat has the `assistant` role and
|
||||
`False` otherwise, but you can manually override that behaviour by setting this flag.
|
||||
prefix (`str`, *optional*):
|
||||
Prefix added to prompt.
|
||||
handle_long_generation (`str`, *optional*):
|
||||
@@ -270,6 +280,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
truncation=None,
|
||||
padding=None,
|
||||
max_length=None,
|
||||
continue_final_message=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||
@@ -283,9 +294,14 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
if isinstance(prompt_text, Chat):
|
||||
tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats
|
||||
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
|
||||
# because very few models support multiple separate, consecutive assistant messages
|
||||
if continue_final_message is None:
|
||||
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
prompt_text.messages,
|
||||
add_generation_prompt=True,
|
||||
add_generation_prompt=not continue_final_message,
|
||||
continue_final_message=continue_final_message,
|
||||
return_dict=True,
|
||||
return_tensors=self.framework,
|
||||
**tokenizer_kwargs,
|
||||
@@ -356,7 +372,13 @@ class TextGenerationPipeline(Pipeline):
|
||||
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
||||
|
||||
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
||||
def postprocess(
|
||||
self,
|
||||
model_outputs,
|
||||
return_type=ReturnType.FULL_TEXT,
|
||||
clean_up_tokenization_spaces=True,
|
||||
continue_final_message=None,
|
||||
):
|
||||
generated_sequence = model_outputs["generated_sequence"][0]
|
||||
input_ids = model_outputs["input_ids"]
|
||||
prompt_text = model_outputs["prompt_text"]
|
||||
@@ -390,9 +412,21 @@ class TextGenerationPipeline(Pipeline):
|
||||
if isinstance(prompt_text, str):
|
||||
all_text = prompt_text + all_text
|
||||
elif isinstance(prompt_text, Chat):
|
||||
# Explicit list parsing is necessary for parsing chat datasets
|
||||
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||
|
||||
if continue_final_message is None:
|
||||
# If the user passes a chat ending in an assistant message, we treat it as a prefill by
|
||||
# default because very few models support multiple separate, consecutive assistant messages
|
||||
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
|
||||
if continue_final_message:
|
||||
# With assistant prefill, concat onto the end of the last message
|
||||
all_text = list(prompt_text.messages)[:-1] + [
|
||||
{
|
||||
"role": prompt_text.messages[-1]["role"],
|
||||
"content": prompt_text.messages[-1]["content"] + all_text,
|
||||
}
|
||||
]
|
||||
else:
|
||||
# When we're not starting from a prefill, the output is a new assistant message
|
||||
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||
record = {"generated_text": all_text}
|
||||
records.append(record)
|
||||
|
||||
|
||||
@@ -1704,6 +1704,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
add_generation_prompt: bool = False,
|
||||
continue_final_message: bool = False,
|
||||
tokenize: bool = True,
|
||||
padding: bool = False,
|
||||
truncation: bool = False,
|
||||
@@ -1737,10 +1738,16 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
chat_template (`str`, *optional*):
|
||||
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
|
||||
argument, as the model's template will be used by default.
|
||||
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
|
||||
the start of an assistant message. This is useful when you want to generate a response from the model.
|
||||
add_generation_prompt (bool, *optional*):
|
||||
If this is set, a prompt with the token(s) that indicate
|
||||
the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
|
||||
Note that this argument will be passed to the chat template, and so it must be supported in the
|
||||
template for this argument to have any effect.
|
||||
continue_final_message (bool, *optional*):
|
||||
If this is set, the chat will be formatted so that the final
|
||||
message in the chat is open-ended, without any EOS tokens. The model will continue this message
|
||||
rather than starting a new one. This allows you to "prefill" part of
|
||||
the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
|
||||
tokenize (`bool`, defaults to `True`):
|
||||
Whether to tokenize the output. If `False`, the output will be a string.
|
||||
padding (`bool`, defaults to `False`):
|
||||
@@ -1803,6 +1810,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
conversations = [conversation]
|
||||
is_batched = False
|
||||
|
||||
if continue_final_message:
|
||||
if add_generation_prompt:
|
||||
raise ValueError(
|
||||
"continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
|
||||
)
|
||||
if return_assistant_tokens_mask:
|
||||
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
|
||||
|
||||
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
|
||||
if tools is not None:
|
||||
tool_schemas = []
|
||||
@@ -1849,6 +1864,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**template_kwargs,
|
||||
)
|
||||
if continue_final_message:
|
||||
final_message = chat[-1]["content"]
|
||||
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
|
||||
rendered.append(rendered_chat)
|
||||
|
||||
if not is_batched:
|
||||
|
||||
@@ -877,8 +877,8 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
# See https://github.com/huggingface/transformers/issues/31669
|
||||
text_generator = pipeline(
|
||||
"text-generation",
|
||||
model="Rocketknight1/fake-custom-model-test",
|
||||
tokenizer="Rocketknight1/fake-custom-model-test",
|
||||
model="hf-internal-testing/tiny-random-custom-architecture",
|
||||
tokenizer="hf-internal-testing/tiny-random-custom-architecture",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
@@ -888,8 +888,8 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
def test_custom_code_with_string_feature_extractor(self):
|
||||
speech_recognizer = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="Rocketknight1/fake-custom-wav2vec2",
|
||||
feature_extractor="Rocketknight1/fake-custom-wav2vec2",
|
||||
model="hf-internal-testing/fake-custom-wav2vec2",
|
||||
feature_extractor="hf-internal-testing/fake-custom-wav2vec2",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
@@ -899,8 +899,8 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
def test_custom_code_with_string_preprocessor(self):
|
||||
mask_generator = pipeline(
|
||||
"mask-generation",
|
||||
model="Rocketknight1/fake-custom-sam",
|
||||
processor="Rocketknight1/fake-custom-sam",
|
||||
model="hf-internal-testing/fake-custom-sam",
|
||||
processor="hf-internal-testing/fake-custom-sam",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -148,18 +148,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_small_chat_model_pt(self):
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
chat2 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a second test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||
expected_chat1 = chat1 + [
|
||||
@@ -179,7 +177,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
expected_chat2 = chat2 + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
}
|
||||
]
|
||||
|
||||
@@ -191,6 +189,68 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_continue_final_message(self):
|
||||
# Here we check that passing a chat that ends in an assistant message is handled correctly
|
||||
# by continuing the final message rather than starting a new one
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||
|
||||
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"generated_text": [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "This is stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
},
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_continue_final_message_override(self):
|
||||
# Here we check that passing a chat that ends in an assistant message is handled correctly
|
||||
# by continuing the final message rather than starting a new one
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10, continue_final_message=True)
|
||||
|
||||
# Assert that we continued the last message and there isn't a sneaky <|im_end|>
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"generated_text": [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
},
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_chat_model_with_dataset_pt(self):
|
||||
from torch.utils.data import Dataset
|
||||
@@ -202,7 +262,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
[
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
],
|
||||
]
|
||||
|
||||
@@ -213,7 +272,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
return {"text": self.data[i]}
|
||||
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
|
||||
)
|
||||
|
||||
dataset = MyDataset()
|
||||
@@ -277,18 +336,16 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
@require_tf
|
||||
def test_small_chat_model_tf(self):
|
||||
text_generator = pipeline(
|
||||
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
|
||||
task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="tf"
|
||||
)
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
chat1 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
chat2 = [
|
||||
{"role": "system", "content": "This is a system message."},
|
||||
{"role": "user", "content": "This is a second test"},
|
||||
{"role": "assistant", "content": "This is a reply"},
|
||||
]
|
||||
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
|
||||
expected_chat1 = chat1 + [
|
||||
@@ -308,7 +365,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
expected_chat2 = chat2 + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": " factors factors factors factors factors factors factors factors factors factors",
|
||||
"content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@@ -1327,6 +1327,36 @@ class TokenizerTesterMixin:
|
||||
[0] * (assistant_start2 - assistant_end - 1),
|
||||
)
|
||||
|
||||
@require_jinja
|
||||
def test_continue_final_message(self):
|
||||
dummy_template = """
|
||||
{%- for message in messages %}
|
||||
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
|
||||
{%- endfor %}"""
|
||||
dummy_conversation = [
|
||||
{"role": "system", "content": "system message"},
|
||||
{"role": "user", "content": "user message"},
|
||||
{"role": "assistant", "content": "assistant message"},
|
||||
]
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
output = tokenizer.apply_chat_template(
|
||||
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
|
||||
)
|
||||
self.assertEqual(
|
||||
output,
|
||||
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
|
||||
)
|
||||
prefill_output = tokenizer.apply_chat_template(
|
||||
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
|
||||
)
|
||||
# Assert that the final message is unterminated
|
||||
self.assertEqual(
|
||||
prefill_output,
|
||||
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
|
||||
)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_dict(self):
|
||||
dummy_template_1 = "{{'a'}}"
|
||||
|
||||
Reference in New Issue
Block a user