Enable some Jinja extensions and add datetime capabilities (#32684)

* Add new Jinja features:

- Do extension
- Break/continue in loops
- Call strftime to get current datetime in any format

* Add new Jinja features:

- Do extension
- Break/continue in loops
- Call strftime to get current datetime in any format

* Fix strftime template

* Add template strip() just to be safe

* Remove the do extension to make porting easier, and also because it's the least useful

* Rename test

* strftime -> strftime_now

* Split test

* Update test to use strftime_now

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils
This commit is contained in:
Matt
2024-08-23 14:26:12 +01:00
committed by GitHub
parent adb91179b9
commit 371b9c1486
3 changed files with 151 additions and 92 deletions

View File

@@ -1153,6 +1153,51 @@ class TokenizerTesterMixin:
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_jinja_loopcontrols(self):
break_template = """
{%- for message in messages %}
{{- message.role + " " + message.content }}
{%- if loop.first %}
{%- break %}
{%- endif %}
{%- endfor %}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
break_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=break_template, tokenize=False
)
self.assertEqual(break_output, "system 1") # Loop should break after first iter
@require_jinja
def test_jinja_strftime(self):
strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
strftime_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=strftime_template, tokenize=False
)
# Assert that we get a date formatted as expected
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (