Enable some Jinja extensions and add datetime capabilities (#32684)
* Add new Jinja features: - Do extension - Break/continue in loops - Call strftime to get current datetime in any format * Add new Jinja features: - Do extension - Break/continue in loops - Call strftime to get current datetime in any format * Fix strftime template * Add template strip() just to be safe * Remove the do extension to make porting easier, and also because it's the least useful * Rename test * strftime -> strftime_now * Split test * Update test to use strftime_now * Refactor everything out into chat_template_utils * Refactor everything out into chat_template_utils * Refactor everything out into chat_template_utils * Refactor everything out into chat_template_utils * Refactor everything out into chat_template_utils
This commit is contained in:
@@ -27,7 +27,6 @@ from collections import UserDict
|
|||||||
from collections.abc import Mapping, Sized
|
from collections.abc import Mapping, Sized
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
@@ -65,6 +64,7 @@ from .utils import (
|
|||||||
requires_backends,
|
requires_backends,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
)
|
)
|
||||||
|
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -1791,7 +1791,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compilation function uses a cache to avoid recompiling the same template
|
# Compilation function uses a cache to avoid recompiling the same template
|
||||||
compiled_template = self._compile_jinja_template(chat_template)
|
compiled_template = _compile_jinja_template(chat_template)
|
||||||
|
|
||||||
if isinstance(conversation, (list, tuple)) and (
|
if isinstance(conversation, (list, tuple)) and (
|
||||||
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
|
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
|
||||||
@@ -1831,7 +1831,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# Indicates it's a Conversation object
|
# Indicates it's a Conversation object
|
||||||
chat = chat.messages
|
chat = chat.messages
|
||||||
if return_assistant_tokens_mask:
|
if return_assistant_tokens_mask:
|
||||||
rendered_chat, generation_indices = self._render_with_assistant_indices(
|
rendered_chat, generation_indices = _render_with_assistant_indices(
|
||||||
compiled_template=compiled_template,
|
compiled_template=compiled_template,
|
||||||
messages=chat,
|
messages=chat,
|
||||||
tools=tool_schemas,
|
tools=tool_schemas,
|
||||||
@@ -1888,94 +1888,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
return rendered
|
return rendered
|
||||||
|
|
||||||
def _render_with_assistant_indices(
|
|
||||||
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
|
|
||||||
):
|
|
||||||
rendered_blocks = []
|
|
||||||
generation_indices = []
|
|
||||||
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
|
|
||||||
for block in compiled_template.generate(
|
|
||||||
messages=messages,
|
|
||||||
tools=tools,
|
|
||||||
documents=documents,
|
|
||||||
add_generation_prompt=add_generation_prompt,
|
|
||||||
**template_kwargs,
|
|
||||||
):
|
|
||||||
rendered_blocks.append(block)
|
|
||||||
rendered_chat = "".join(rendered_blocks)
|
|
||||||
return rendered_chat, generation_indices
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def _compile_jinja_template(self, chat_template):
|
|
||||||
try:
|
|
||||||
import jinja2
|
|
||||||
from jinja2 import nodes
|
|
||||||
from jinja2.exceptions import TemplateError
|
|
||||||
from jinja2.ext import Extension
|
|
||||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("apply_chat_template requires jinja2 to be installed.")
|
|
||||||
|
|
||||||
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
|
||||||
raise ImportError(
|
|
||||||
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def raise_exception(message):
|
|
||||||
raise TemplateError(message)
|
|
||||||
|
|
||||||
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
|
|
||||||
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
|
|
||||||
# We also expose some options like custom indents and separators
|
|
||||||
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
|
||||||
|
|
||||||
class AssistantTracker(Extension):
|
|
||||||
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
|
|
||||||
tags = {"generation"}
|
|
||||||
|
|
||||||
def __init__(self, environment: ImmutableSandboxedEnvironment):
|
|
||||||
# The class is only initiated by jinja.
|
|
||||||
super().__init__(environment)
|
|
||||||
environment.extend(activate_tracker=self.activate_tracker)
|
|
||||||
self._rendered_blocks = None
|
|
||||||
self._generation_indices = None
|
|
||||||
|
|
||||||
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
|
||||||
lineno = next(parser.stream).lineno
|
|
||||||
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
|
||||||
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
|
|
||||||
|
|
||||||
@jinja2.pass_eval_context
|
|
||||||
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
|
|
||||||
rv = caller()
|
|
||||||
if self.is_active():
|
|
||||||
# Only track generation indices if the tracker is active
|
|
||||||
start_index = len("".join(self._rendered_blocks))
|
|
||||||
end_index = start_index + len(rv)
|
|
||||||
self._generation_indices.append((start_index, end_index))
|
|
||||||
return rv
|
|
||||||
|
|
||||||
def is_active(self) -> bool:
|
|
||||||
return self._rendered_blocks or self._generation_indices
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
|
|
||||||
try:
|
|
||||||
if self.is_active():
|
|
||||||
raise ValueError("AssistantTracker should not be reused before closed")
|
|
||||||
self._rendered_blocks = rendered_blocks
|
|
||||||
self._generation_indices = generation_indices
|
|
||||||
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self._rendered_blocks = None
|
|
||||||
self._generation_indices = None
|
|
||||||
|
|
||||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
|
|
||||||
jinja_env.filters["tojson"] = tojson
|
|
||||||
jinja_env.globals["raise_exception"] = raise_exception
|
|
||||||
return jinja_env.from_string(chat_template)
|
|
||||||
|
|
||||||
def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
|
def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve the chat template string used for tokenizing chat messages. This template is used
|
Retrieve the chat template string used for tokenizing chat messages. This template is used
|
||||||
|
|||||||
@@ -15,7 +15,22 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from .import_utils import is_jinja_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_jinja_available():
|
||||||
|
import jinja2
|
||||||
|
from jinja2.ext import Extension
|
||||||
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||||
|
else:
|
||||||
|
jinja2 = None
|
||||||
|
|
||||||
|
|
||||||
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
|
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
|
||||||
@@ -314,3 +329,90 @@ def get_json_schema(func: Callable) -> Dict:
|
|||||||
if return_dict is not None:
|
if return_dict is not None:
|
||||||
output["return"] = return_dict
|
output["return"] = return_dict
|
||||||
return {"type": "function", "function": output}
|
return {"type": "function", "function": output}
|
||||||
|
|
||||||
|
|
||||||
|
def _render_with_assistant_indices(
|
||||||
|
compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
|
||||||
|
):
|
||||||
|
rendered_blocks = []
|
||||||
|
generation_indices = []
|
||||||
|
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
|
||||||
|
for block in compiled_template.generate(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
documents=documents,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**template_kwargs,
|
||||||
|
):
|
||||||
|
rendered_blocks.append(block)
|
||||||
|
rendered_chat = "".join(rendered_blocks)
|
||||||
|
return rendered_chat, generation_indices
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def _compile_jinja_template(chat_template):
|
||||||
|
class AssistantTracker(Extension):
|
||||||
|
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
|
||||||
|
tags = {"generation"}
|
||||||
|
|
||||||
|
def __init__(self, environment: ImmutableSandboxedEnvironment):
|
||||||
|
# The class is only initiated by jinja.
|
||||||
|
super().__init__(environment)
|
||||||
|
environment.extend(activate_tracker=self.activate_tracker)
|
||||||
|
self._rendered_blocks = None
|
||||||
|
self._generation_indices = None
|
||||||
|
|
||||||
|
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
||||||
|
lineno = next(parser.stream).lineno
|
||||||
|
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
||||||
|
return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
|
||||||
|
|
||||||
|
@jinja2.pass_eval_context
|
||||||
|
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
|
||||||
|
rv = caller()
|
||||||
|
if self.is_active():
|
||||||
|
# Only track generation indices if the tracker is active
|
||||||
|
start_index = len("".join(self._rendered_blocks))
|
||||||
|
end_index = start_index + len(rv)
|
||||||
|
self._generation_indices.append((start_index, end_index))
|
||||||
|
return rv
|
||||||
|
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
return self._rendered_blocks or self._generation_indices
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
|
||||||
|
try:
|
||||||
|
if self.is_active():
|
||||||
|
raise ValueError("AssistantTracker should not be reused before closed")
|
||||||
|
self._rendered_blocks = rendered_blocks
|
||||||
|
self._generation_indices = generation_indices
|
||||||
|
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self._rendered_blocks = None
|
||||||
|
self._generation_indices = None
|
||||||
|
|
||||||
|
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
|
||||||
|
raise ImportError(
|
||||||
|
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def raise_exception(message):
|
||||||
|
raise jinja2.exceptions.TemplateError(message)
|
||||||
|
|
||||||
|
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
|
||||||
|
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
|
||||||
|
# We also expose some options like custom indents and separators
|
||||||
|
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
||||||
|
|
||||||
|
def strftime_now(format):
|
||||||
|
return datetime.now().strftime(format)
|
||||||
|
|
||||||
|
jinja_env = ImmutableSandboxedEnvironment(
|
||||||
|
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
|
||||||
|
)
|
||||||
|
jinja_env.filters["tojson"] = tojson
|
||||||
|
jinja_env.globals["raise_exception"] = raise_exception
|
||||||
|
jinja_env.globals["strftime_now"] = strftime_now
|
||||||
|
return jinja_env.from_string(chat_template)
|
||||||
|
|||||||
@@ -1153,6 +1153,51 @@ class TokenizerTesterMixin:
|
|||||||
dummy_conversations, chat_template=dummy_template, tokenize=True
|
dummy_conversations, chat_template=dummy_template, tokenize=True
|
||||||
) # Check that no error raised
|
) # Check that no error raised
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_jinja_loopcontrols(self):
|
||||||
|
break_template = """
|
||||||
|
{%- for message in messages %}
|
||||||
|
{{- message.role + " " + message.content }}
|
||||||
|
{%- if loop.first %}
|
||||||
|
{%- break %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}""".strip()
|
||||||
|
|
||||||
|
dummy_conversation = [
|
||||||
|
{"role": "system", "content": "1"},
|
||||||
|
{"role": "user", "content": "2"},
|
||||||
|
{"role": "assistant", "content": "3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
break_output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=break_template, tokenize=False
|
||||||
|
)
|
||||||
|
self.assertEqual(break_output, "system 1") # Loop should break after first iter
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_jinja_strftime(self):
|
||||||
|
strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
|
||||||
|
|
||||||
|
dummy_conversation = [
|
||||||
|
{"role": "system", "content": "1"},
|
||||||
|
{"role": "user", "content": "2"},
|
||||||
|
{"role": "assistant", "content": "3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
strftime_output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=strftime_template, tokenize=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that we get a date formatted as expected
|
||||||
|
self.assertEqual(len(strftime_output), 10)
|
||||||
|
self.assertEqual(len(strftime_output.split("-")), 3)
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
dummy_template = (
|
dummy_template = (
|
||||||
|
|||||||
Reference in New Issue
Block a user