Retain newlines in chat template when continue_final_message=True (#34253)
* Retain newlines in chat template when * Add try/except * Add regression test * Simplify test * Apply suggestions from code review Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> --------- Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -1690,8 +1690,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
final_message = chat[-1]["content"]
|
final_message = chat[-1]["content"]
|
||||||
if isinstance(final_message, (list, tuple)):
|
if isinstance(final_message, (list, tuple)):
|
||||||
final_message = final_message[-1]["text"]
|
final_message = final_message[-1]["text"]
|
||||||
|
try:
|
||||||
|
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
|
||||||
|
except: # noqa: E722
|
||||||
|
# Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here.
|
||||||
final_message = final_message.strip()
|
final_message = final_message.strip()
|
||||||
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
|
rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)]
|
||||||
rendered.append(rendered_chat)
|
rendered.append(rendered_chat)
|
||||||
|
|
||||||
if not is_batched:
|
if not is_batched:
|
||||||
|
|||||||
@@ -1461,6 +1461,38 @@ class TokenizerTesterMixin:
|
|||||||
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
|
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_continue_final_message_with_trim(self):
|
||||||
|
"""Regression test for chat templates with trimming: https://github.com/huggingface/transformers/pull/34214"""
|
||||||
|
|
||||||
|
dummy_template = """
|
||||||
|
{%- for message in messages %}
|
||||||
|
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}}
|
||||||
|
{%- endfor %}"""
|
||||||
|
dummy_conversation = [
|
||||||
|
{"role": "system", "content": "system message"},
|
||||||
|
{"role": "user", "content": "user message"},
|
||||||
|
{"role": "assistant", "content": "assistant message "}, # Note the trailing whitespace
|
||||||
|
]
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
|
||||||
|
)
|
||||||
|
prefill_output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
|
||||||
|
)
|
||||||
|
# Assert that the final message is unterminated
|
||||||
|
self.assertEqual(
|
||||||
|
prefill_output,
|
||||||
|
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
|
||||||
|
)
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_chat_template_dict(self):
|
def test_chat_template_dict(self):
|
||||||
dummy_template_1 = "{{'a'}}"
|
dummy_template_1 = "{{'a'}}"
|
||||||
|
|||||||
Reference in New Issue
Block a user