Retain newlines in chat template when continue_final_message=True (#34253)

* Retain newlines in chat template when

* Add try/except

* Add regression test

* Simplify test

* Apply suggestions from code review

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
lewtun
2024-11-15 15:27:04 +01:00
committed by GitHub
parent a3d69a8994
commit 8ba3e1505e
2 changed files with 38 additions and 2 deletions

View File

@@ -1461,6 +1461,38 @@ class TokenizerTesterMixin:
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
)
@require_jinja
def test_continue_final_message_with_trim(self):
"""Regression test for chat templates with trimming: https://github.com/huggingface/transformers/pull/34214"""
dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}}
{%- endfor %}"""
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message "}, # Note the trailing whitespace
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
)
self.assertEqual(
output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
)
prefill_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
)
# Assert that the final message is unterminated
self.assertEqual(
prefill_output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
)
@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"