Add assistant prefill for chat templates and TextGenerationPipeline (#33198)

* Add assistant prefill to chat templates

* Add assistant prefill to pipeline

* Add assistant prefill to pipeline

* Tweak another test that ended in assistant message

* Update tests that ended in assistant messages

* Update tests that ended in assistant messages

* Replace assistant_prefill with continue_final_message

* Allow passing continue_final_message to pipeline

* Small fixup

* Add continue_final_message as a pipeline kwarg

* Update docstrings

* Move repos to hf-internal-testing!

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Add explanatory comment

* make fixup

* Update chat templating docs to explain continue_final_message

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Matt
2024-09-02 13:23:47 +01:00
committed by GitHub
parent 2d37085817
commit 52a0213755
6 changed files with 199 additions and 23 deletions

View File

@@ -1327,6 +1327,36 @@ class TokenizerTesterMixin:
[0] * (assistant_start2 - assistant_end - 1),
)
@require_jinja
def test_continue_final_message(self):
    """Check that `continue_final_message` decides whether the last turn is closed.

    A three-turn conversation ending in an assistant message is rendered
    twice with a ChatML-like dummy template, for every tokenizer returned by
    `self.get_tokenizers()`:

    * with `continue_final_message=False` the rendered text is expected to
      contain the end-of-turn marker after the assistant message;
    * with `continue_final_message=True` the rendered text is expected to
      stop right after the assistant message content.
    """
    # The template body is kept flush-left so that no extra leading
    # whitespace ends up inside the string literal.
    dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
{%- endfor %}"""
    dummy_conversation = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "assistant message"},
    ]

    # Expected renderings: the second one lacks the trailing end-of-turn marker.
    expected_closed = "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n"
    expected_open = "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message"

    # Arguments common to both renderings; only `continue_final_message` varies.
    render_kwargs = {"chat_template": dummy_template, "tokenize": False}

    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            closed_render = tokenizer.apply_chat_template(
                dummy_conversation, continue_final_message=False, **render_kwargs
            )
            self.assertEqual(closed_render, expected_closed)

            open_render = tokenizer.apply_chat_template(
                dummy_conversation, continue_final_message=True, **render_kwargs
            )
            # The final assistant message must be left unterminated.
            self.assertEqual(open_render, expected_open)
@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"