From 8ba3e1505ec30b7086a7a523355a210395b6cf70 Mon Sep 17 00:00:00 2001 From: lewtun Date: Fri, 15 Nov 2024 15:27:04 +0100 Subject: [PATCH] Retain newlines in chat template when `continue_final_message=True` (#34253) * Retain newlines in chat template when * Add try/except * Add regression test * Simplify test * Apply suggestions from code review Co-authored-by: Matt --------- Co-authored-by: Matt --- src/transformers/tokenization_utils_base.py | 8 ++++-- tests/test_tokenization_common.py | 32 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 381f3ef497..03df02d21f 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1690,8 +1690,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): final_message = chat[-1]["content"] if isinstance(final_message, (list, tuple)): final_message = final_message[-1]["text"] - final_message = final_message.strip() - rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip() + try: + rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)] + except: # noqa: E722 + # Some chat templates like Llama-3.1 trim messages before rendering, so we must do the same here. + final_message = final_message.strip() + rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)] rendered.append(rendered_chat) if not is_batched: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index a3bbbf3c9e..f04a425555 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1461,6 +1461,38 @@ class TokenizerTesterMixin: "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message", ) + @require_jinja + def test_continue_final_message_with_trim(self): + """Regression test for chat templates with trimming: https://github.com/huggingface/transformers/pull/34214""" + + dummy_template = """ + {%- for message in messages %} + {{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>" + "\n"}} + {%- endfor %}""" + dummy_conversation = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message "}, # Note the trailing whitespace + ] + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False + ) + self.assertEqual( + output, + "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n", + ) + prefill_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True + ) + # Assert that the final message is unterminated + self.assertEqual( + prefill_output, + "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message", + ) + @require_jinja def test_chat_template_dict(self): dummy_template_1 = "{{'a'}}"