Enable some Jinja extensions and add datetime capabilities (#32684)

* Add new Jinja features:

- Do extension
- Break/continue in loops
- Call strftime to get current datetime in any format

* Add new Jinja features:

- Do extension
- Break/continue in loops
- Call strftime to get current datetime in any format

* Fix strftime template

* Add template strip() just to be safe

* Remove the do extension to make porting easier, and also because it's the least useful

* Rename test

* strftime -> strftime_now

* Split test

* Update test to use strftime_now

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils
This commit is contained in:
Matt
2024-08-23 14:26:12 +01:00
committed by GitHub
parent adb91179b9
commit 371b9c1486
3 changed files with 151 additions and 92 deletions

View File

@@ -27,7 +27,6 @@ from collections import UserDict
from collections.abc import Mapping, Sized from collections.abc import Mapping, Sized
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache
from inspect import isfunction from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -65,6 +64,7 @@ from .utils import (
requires_backends, requires_backends,
to_py_obj, to_py_obj,
) )
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -1791,7 +1791,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
) )
# Compilation function uses a cache to avoid recompiling the same template # Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template) compiled_template = _compile_jinja_template(chat_template)
if isinstance(conversation, (list, tuple)) and ( if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
@@ -1831,7 +1831,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Indicates it's a Conversation object # Indicates it's a Conversation object
chat = chat.messages chat = chat.messages
if return_assistant_tokens_mask: if return_assistant_tokens_mask:
rendered_chat, generation_indices = self._render_with_assistant_indices( rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template, compiled_template=compiled_template,
messages=chat, messages=chat,
tools=tool_schemas, tools=tool_schemas,
@@ -1888,94 +1888,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
else: else:
return rendered return rendered
def _render_with_assistant_indices(
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices
@lru_cache
def _compile_jinja_template(self, chat_template):
try:
import jinja2
from jinja2 import nodes
from jinja2.exceptions import TemplateError
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)
def raise_exception(message):
raise TemplateError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}
def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv
def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices
@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices
yield
finally:
self._rendered_blocks = None
self._generation_indices = None
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)
def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str: def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
""" """
Retrieve the chat template string used for tokenizing chat messages. This template is used Retrieve the chat template string used for tokenizing chat messages. This template is used

View File

@@ -15,7 +15,22 @@
import inspect import inspect
import json import json
import re import re
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
from packaging import version
from .import_utils import is_jinja_available
if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
jinja2 = None
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
@@ -314,3 +329,90 @@ def get_json_schema(func: Callable) -> Dict:
if return_dict is not None: if return_dict is not None:
output["return"] = return_dict output["return"] = return_dict
return {"type": "function", "function": output} return {"type": "function", "function": output}
def _render_with_assistant_indices(
compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices
@lru_cache
def _compile_jinja_template(chat_template):
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}
def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv
def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices
@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices
yield
finally:
self._rendered_blocks = None
self._generation_indices = None
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)
def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
def strftime_now(format):
return datetime.now().strftime(format)
jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
return jinja_env.from_string(chat_template)

View File

@@ -1153,6 +1153,51 @@ class TokenizerTesterMixin:
dummy_conversations, chat_template=dummy_template, tokenize=True dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised ) # Check that no error raised
@require_jinja
def test_jinja_loopcontrols(self):
break_template = """
{%- for message in messages %}
{{- message.role + " " + message.content }}
{%- if loop.first %}
{%- break %}
{%- endif %}
{%- endfor %}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
break_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=break_template, tokenize=False
)
self.assertEqual(break_output, "system 1") # Loop should break after first iter
@require_jinja
def test_jinja_strftime(self):
strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
strftime_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=strftime_template, tokenize=False
)
# Assert that we get a date formatted as expected
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)
@require_jinja @require_jinja
def test_chat_template_return_assistant_tokens_mask(self): def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = ( dummy_template = (