Add assistant prefill for chat templates and TextGenerationPipeline (#33198)
* Add assistant prefill to chat templates
* Add assistant prefill to pipeline
* Add assistant prefill to pipeline
* Tweak another test that ended in assistant message
* Update tests that ended in assistant messages
* Update tests that ended in assistant messages
* Replace assistant_prefill with continue_final_message
* Allow passing continue_final_message to pipeline
* Small fixup
* Add continue_final_message as a pipeline kwarg
* Update docstrings
* Move repos to hf-internal-testing!
* Update src/transformers/tokenization_utils_base.py

  Co-authored-by: Lysandre Debut <hi@lysand.re>
* Add explanatory comment
* make fixup
* Update chat templating docs to explain continue_last_message

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -1327,6 +1327,36 @@ class TokenizerTesterMixin:
|
||||
[0] * (assistant_start2 - assistant_end - 1),
|
||||
)
|
||||
|
||||
@require_jinja
def test_continue_final_message(self):
    """Check how `continue_final_message` affects the rendered chat string.

    A ChatML-style template is rendered for a conversation whose last turn
    is an assistant message. With `continue_final_message=False` the full
    template output is expected, including the trailing `<|im_end|>` and
    newline. With `continue_final_message=True` the output is expected to
    stop right after the final message's content, leaving it unterminated.
    """
    # The `{%-` / `{{-` markers strip the whitespace that precedes each tag,
    # so the indentation inside this literal does not reach the rendered text.
    chat_template = """
    {%- for message in messages %}
        {{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
    {%- endfor %}"""
    conversation = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "assistant message"},
    ]

    # Expected renderings: every turn closed vs. the last turn left open.
    terminated = "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n"
    unterminated = "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message"

    for tokenizer in self.get_tokenizers():
        with self.subTest(tokenizer.__class__.__name__):
            rendered = tokenizer.apply_chat_template(
                conversation,
                chat_template=chat_template,
                tokenize=False,
                continue_final_message=False,
            )
            self.assertEqual(rendered, terminated)

            continued = tokenizer.apply_chat_template(
                conversation,
                chat_template=chat_template,
                tokenize=False,
                continue_final_message=True,
            )
            # Assert that the final message is unterminated
            self.assertEqual(continued, unterminated)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_dict(self):
|
||||
dummy_template_1 = "{{'a'}}"
|
||||
|
||||
Reference in New Issue
Block a user