Return assistant generated tokens mask in apply_chat_template (#30650)
Return the assistant-generated tokens mask in `apply_chat_template`.
This commit is contained in:
@@ -1697,6 +1697,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
return_dict: bool = False,
|
return_dict: bool = False,
|
||||||
|
return_assistant_tokens_mask: bool = False,
|
||||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
|
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
|
||||||
@@ -1747,6 +1748,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
return_dict (`bool`, defaults to `False`):
|
return_dict (`bool`, defaults to `False`):
|
||||||
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
||||||
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
|
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
|
||||||
|
return_assistant_tokens_mask (`bool`, defaults to `False`):
|
||||||
|
Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
|
||||||
|
the mask will contain 1. For user and system tokens, the mask will contain 0.
|
||||||
|
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
|
||||||
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1761,6 +1766,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
"of tokenizer outputs to return."
|
"of tokenizer outputs to return."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if return_assistant_tokens_mask and not return_dict:
|
||||||
|
raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
|
||||||
|
|
||||||
if tokenizer_kwargs is None:
|
if tokenizer_kwargs is None:
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
|
|
||||||
@@ -1813,6 +1821,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
"then to ensure that this model continues working without issues."
|
"then to ensure that this model continues working without issues."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
|
||||||
|
logger.warning_once(
|
||||||
|
"return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
|
||||||
|
)
|
||||||
|
|
||||||
# Compilation function uses a cache to avoid recompiling the same template
|
# Compilation function uses a cache to avoid recompiling the same template
|
||||||
compiled_template = self._compile_jinja_template(chat_template)
|
compiled_template = self._compile_jinja_template(chat_template)
|
||||||
|
|
||||||
@@ -1847,18 +1860,30 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
|
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
|
||||||
|
|
||||||
rendered = []
|
rendered = []
|
||||||
|
all_generation_indices = []
|
||||||
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
||||||
for chat in conversations:
|
for chat in conversations:
|
||||||
if hasattr(chat, "messages"):
|
if hasattr(chat, "messages"):
|
||||||
# Indicates it's a Conversation object
|
# Indicates it's a Conversation object
|
||||||
chat = chat.messages
|
chat = chat.messages
|
||||||
rendered_chat = compiled_template.render(
|
if return_assistant_tokens_mask:
|
||||||
messages=chat,
|
rendered_chat, generation_indices = self._render_with_assistant_indices(
|
||||||
tools=tool_schemas,
|
compiled_template=compiled_template,
|
||||||
documents=documents,
|
messages=chat,
|
||||||
add_generation_prompt=add_generation_prompt,
|
tools=tool_schemas,
|
||||||
**template_kwargs,
|
documents=documents,
|
||||||
)
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**template_kwargs,
|
||||||
|
)
|
||||||
|
all_generation_indices.append(generation_indices)
|
||||||
|
else:
|
||||||
|
rendered_chat = compiled_template.render(
|
||||||
|
messages=chat,
|
||||||
|
tools=tool_schemas,
|
||||||
|
documents=documents,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**template_kwargs,
|
||||||
|
)
|
||||||
rendered.append(rendered_chat)
|
rendered.append(rendered_chat)
|
||||||
|
|
||||||
if not is_batched:
|
if not is_batched:
|
||||||
@@ -1875,17 +1900,54 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
)
|
)
|
||||||
if return_dict:
|
if return_dict:
|
||||||
|
if return_assistant_tokens_mask:
|
||||||
|
assistant_masks = []
|
||||||
|
if is_batched or return_tensors:
|
||||||
|
input_ids = out["input_ids"]
|
||||||
|
else:
|
||||||
|
input_ids = [out["input_ids"]]
|
||||||
|
for i in range(len(input_ids)):
|
||||||
|
current_mask = [0] * len(input_ids[i])
|
||||||
|
for assistant_start_char, assistant_end_char in all_generation_indices[i]:
|
||||||
|
start_token = out.char_to_token(i, assistant_start_char)
|
||||||
|
end_token = out.char_to_token(i, assistant_end_char - 1)
|
||||||
|
if start_token is None:
|
||||||
|
# start_token is out of bounds maybe due to truncation.
|
||||||
|
break
|
||||||
|
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids)):
|
||||||
|
current_mask[token_id] = 1
|
||||||
|
assistant_masks.append(current_mask)
|
||||||
|
out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
|
||||||
return out
|
return out
|
||||||
else:
|
else:
|
||||||
return out["input_ids"]
|
return out["input_ids"]
|
||||||
else:
|
else:
|
||||||
return rendered
|
return rendered
|
||||||
|
|
||||||
|
def _render_with_assistant_indices(
    self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
    """
    Render a chat with `compiled_template` while recording the assistant-generated character spans.

    The template is rendered incrementally through `compiled_template.generate(...)` instead of
    `render(...)`: every yielded block is appended to `rendered_blocks` as soon as it is produced, so the
    tracker activated on the template's environment can see how much text has been emitted so far.

    Args:
        compiled_template: Compiled template exposing `generate(**kwargs)` and an `environment` that provides
            the `activate_tracker(rendered_blocks, generation_indices)` context manager.
        messages: The conversation passed to the template as `messages`.
        tools: Passed to the template as `tools`.
        documents: Passed to the template as `documents`.
        add_generation_prompt: Passed to the template as `add_generation_prompt`.
        **template_kwargs: Any additional variables forwarded to the template renderer.

    Returns:
        A `(rendered_chat, generation_indices)` tuple: the full rendered string and the list that the tracker
        filled while rendering.
    """
    rendered_blocks = []
    generation_indices = []
    # Both lists are handed to the tracker by reference; it reads `rendered_blocks` and appends to
    # `generation_indices` while the template is being generated.
    with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
        for block in compiled_template.generate(
            messages=messages,
            tools=tools,
            documents=documents,
            add_generation_prompt=add_generation_prompt,
            **template_kwargs,
        ):
            # Append block by block (rather than collecting everything at once) so the tracker always
            # observes the text rendered so far.
            rendered_blocks.append(block)
    rendered_chat = "".join(rendered_blocks)
    return rendered_chat, generation_indices
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def _compile_jinja_template(self, chat_template):
|
def _compile_jinja_template(self, chat_template):
|
||||||
try:
|
try:
|
||||||
import jinja2
|
import jinja2
|
||||||
|
from jinja2 import nodes
|
||||||
from jinja2.exceptions import TemplateError
|
from jinja2.exceptions import TemplateError
|
||||||
|
from jinja2.ext import Extension
|
||||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("apply_chat_template requires jinja2 to be installed.")
|
raise ImportError("apply_chat_template requires jinja2 to be installed.")
|
||||||
@@ -1903,7 +1965,49 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# We also expose some options like custom indents and separators
|
# We also expose some options like custom indents and separators
|
||||||
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
|
||||||
|
|
||||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
class AssistantTracker(Extension):
    """
    Jinja extension that tracks the character spans of assistant-generated text in the rendered chat.

    It adds a `{% generation %}` ... `{% endgeneration %}` block tag. The block renders its body unchanged;
    while a tracker is active (see `activate_tracker`) it also records the `(start, end)` character offsets
    of the body within the text rendered so far.
    """

    tags = {"generation"}

    def __init__(self, environment: ImmutableSandboxedEnvironment):
        # The class is only initiated by jinja.
        super().__init__(environment)
        # Expose the context manager on the environment so callers can reach it through
        # `compiled_template.environment.activate_tracker(...)`.
        environment.extend(activate_tracker=self.activate_tracker)
        # Both are `None` while no tracker is active.
        self._rendered_blocks = None
        self._generation_indices = None

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        # Consume the `generation` tag token and remember its line for error reporting.
        lineno = next(parser.stream).lineno
        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

    @jinja2.pass_eval_context
    def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
        # Render the body of the generation block; it is returned unchanged in every case.
        rv = caller()
        if self.is_active():
            # Only track generation indices if the tracker is active.
            # The block starts where the text rendered so far ends.
            start_index = sum(len(block) for block in self._rendered_blocks)
            end_index = start_index + len(rv)
            self._generation_indices.append((start_index, end_index))
        return rv

    def is_active(self) -> bool:
        # Compare against `None` rather than relying on truthiness: right after activation both lists are
        # still empty (and therefore falsy), yet the tracker must already count as active.
        return self._rendered_blocks is not None or self._generation_indices is not None

    @contextmanager
    def activate_tracker(self, rendered_blocks: List[str], generation_indices: List[tuple]):
        """
        Activate tracking for the duration of the `with` block.

        Args:
            rendered_blocks: List the caller appends rendered text blocks to; only read here.
            generation_indices: List that receives one `(start, end)` tuple per generation block.

        Raises:
            ValueError: If a tracker is already active on this extension instance.
        """
        # Check before entering `try` so that a rejected activation does not run the `finally` clause and
        # wipe the state of the activation that is already in progress.
        if self.is_active():
            raise ValueError("AssistantTracker should not be reused before closed")
        self._rendered_blocks = rendered_blocks
        self._generation_indices = generation_indices
        try:
            yield
        finally:
            self._rendered_blocks = None
            self._generation_indices = None
|
||||||
|
|
||||||
|
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
|
||||||
jinja_env.filters["tojson"] = tojson
|
jinja_env.filters["tojson"] = tojson
|
||||||
jinja_env.globals["raise_exception"] = raise_exception
|
jinja_env.globals["raise_exception"] = raise_exception
|
||||||
return jinja_env.from_string(chat_template)
|
return jinja_env.from_string(chat_template)
|
||||||
|
|||||||
@@ -2483,3 +2483,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip(reason="Chat is not supported")
|
@unittest.skip(reason="Chat is not supported")
|
||||||
def test_chat_template(self):
|
def test_chat_template(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Skipped for this tokenizer; `reason=` is passed by keyword to match the neighbouring skip decorators.
@unittest.skip(reason="Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
    pass
|
||||||
|
|||||||
@@ -2436,3 +2436,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip(reason="Chat is not supported")
|
@unittest.skip(reason="Chat is not supported")
|
||||||
def test_chat_template(self):
|
def test_chat_template(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Skipped for this tokenizer; `reason=` is passed by keyword to match the neighbouring skip decorators.
@unittest.skip(reason="Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
    pass
|
||||||
|
|||||||
@@ -1977,3 +1977,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip(reason="Chat is not supported")
|
@unittest.skip(reason="Chat is not supported")
|
||||||
def test_chat_template(self):
|
def test_chat_template(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Skipped for this tokenizer; `reason=` is passed by keyword to match the neighbouring skip decorators.
@unittest.skip(reason="Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
    pass
|
||||||
|
|||||||
@@ -2316,3 +2316,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
|
@unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
|
||||||
def test_added_tokens_serialization(self):
|
def test_added_tokens_serialization(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Skipped for this tokenizer; `reason=` is passed by keyword to match the neighbouring skip decorators.
@unittest.skip(reason="Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
    pass
|
||||||
|
|||||||
@@ -1277,3 +1277,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip(reason="Chat is not supported")
|
@unittest.skip(reason="Chat is not supported")
|
||||||
def test_chat_template(self):
|
def test_chat_template(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Skipped for this tokenizer; `reason=` is passed by keyword to match the neighbouring skip decorators.
@unittest.skip(reason="Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
    pass
|
||||||
|
|||||||
@@ -1157,6 +1157,10 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_chat_template(self):
|
def test_chat_template(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
def test_chat_template_return_assistant_tokens_mask(self):
    # Intentionally empty: the skip decorator above disables this test for this tokenizer.
    pass
|
||||||
|
|
||||||
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
|
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
|
||||||
def test_chat_template_batched(self):
|
def test_chat_template_batched(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1153,6 +1153,135 @@ class TokenizerTesterMixin:
|
|||||||
dummy_conversations, chat_template=dummy_template, tokenize=True
|
dummy_conversations, chat_template=dummy_template, tokenize=True
|
||||||
) # Check that no error raised
|
) # Check that no error raised
|
||||||
|
|
||||||
|
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
    """
    Check the `assistant_masks` returned by `apply_chat_template(..., return_assistant_tokens_mask=True)`.

    A dummy template wraps the assistant content in `{% generation %}` ... `{% endgeneration %}`. For both a
    batched and a single-conversation call, the test asserts that the mask is 1 over each of the two
    assistant turns and 0 before the first turn and between the two turns.
    """
    dummy_template = (
        "{% for message in messages %}"
        "{% if (message['role'] != 'assistant') %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% elif (message['role'] == 'assistant')%}"
        "{{'<|im_start|>' + message['role'] + '\n'}}"
        "{% generation %}"
        "{{message['content'] + '<|im_end|>'}}"
        "{% endgeneration %}"
        "{{'\n'}}"
        "{% endif %}"
        "{% endfor %}"
    )
    conversations = [
        [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "user message"},
            {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
            {"role": "user", "content": "user message 2"},
            {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
        ],
        [
            {"role": "system", "content": "system message 3"},
            {"role": "user", "content": "user message 3"},
            {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
            {"role": "user", "content": "user message 4"},
            {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
        ],
    ]

    # These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
    # in the entire chat string, and then find the corresponding tokens in the tokenized output.
    assistant_prefix_suffix = [
        [("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
        [("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
    ]

    def token_span(encoding, batch_index, chat_string, prefix, suffix):
        # Map the first character of `prefix` and the last character of `suffix` to token indices.
        start = encoding.char_to_token(batch_index, chat_string.index(prefix))
        end = encoding.char_to_token(batch_index, chat_string.index(suffix) + len(suffix) - 1)
        return start, end

    def check_mask(mask, first_span, second_span):
        start, end = first_span
        start2, end2 = second_span
        # assert 1 in first assistant message
        self.assertEqual(mask[start : end + 1], [1] * (end - start + 1))
        # assert 1 in second assistant message
        self.assertEqual(mask[start2 : end2 + 1], [1] * (end2 - start2 + 1))
        # assert 0 in user/system indices
        self.assertEqual(mask[:start], [0] * start)
        self.assertEqual(mask[end + 1 : start2], [0] * (start2 - end - 1))

    for tokenizer, pretrained_name, _ in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            if not self.test_rust_tokenizer:
                self.skipTest(reason="No fast tokenizer defined")

            tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)

            # check batched
            output = tokenizer_r.apply_chat_template(
                conversations,
                chat_template=dummy_template,
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )
            for i, conv in enumerate(conversations):
                chat_string = tokenizer_r.apply_chat_template(
                    conv, tokenize=False, chat_template=dummy_template
                )
                (prefix1, suffix1), (prefix2, suffix2) = assistant_prefix_suffix[i]
                check_mask(
                    output["assistant_masks"][i],
                    token_span(output, i, chat_string, prefix1, suffix1),
                    token_span(output, i, chat_string, prefix2, suffix2),
                )

            # check not batched
            output = tokenizer_r.apply_chat_template(
                conversations[0],
                chat_template=dummy_template,
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )

            chat_string = tokenizer_r.apply_chat_template(
                conversations[0], tokenize=False, chat_template=dummy_template
            )
            (prefix1, suffix1), (prefix2, suffix2) = assistant_prefix_suffix[0]
            # For a single conversation the mask is a flat list rather than a list of lists.
            check_mask(
                output["assistant_masks"],
                token_span(output, 0, chat_string, prefix1, suffix1),
                token_span(output, 0, chat_string, prefix2, suffix2),
            )
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_chat_template_dict(self):
|
def test_chat_template_dict(self):
|
||||||
dummy_template_1 = "{{'a'}}"
|
dummy_template_1 = "{{'a'}}"
|
||||||
|
|||||||
Reference in New Issue
Block a user