Return assistant generated tokens mask in apply_chat_template (#30650)

return assistant generated tokens mask in apply_chat_template
This commit is contained in:
Yoni Gottesman
2024-07-22 20:24:43 +03:00
committed by GitHub
parent 7987710696
commit 74d0eb3fed
8 changed files with 265 additions and 8 deletions

View File

@@ -1153,6 +1153,135 @@ class TokenizerTesterMixin:
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|im_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'] + '<|im_end|>'}}"
"{% endgeneration %}"
"{{'\n'}}"
"{% endif %}"
"{% endfor %}"
)
conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
],
[
{"role": "system", "content": "system message 3"},
{"role": "user", "content": "user message 3"},
{"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
{"role": "user", "content": "user message 4"},
{"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
],
]
# These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
# in the entire chat string, and then find the corresponding tokens in the tokenized output.
assistant_prefix_suffix = [
[("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
[("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
]
for tokenizer, pretrained_name, _ in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
if not self.test_rust_tokenizer:
self.skipTest(reason="No fast tokenizer defined")
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
# check batched
output = tokenizer_r.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
for i, conv in enumerate(conversations):
chat_string = tokenizer_r.apply_chat_template(
conversations[i], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
assistant_end = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][0][1])
+ len(assistant_prefix_suffix[i][0][1])
- 1,
)
assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
assistant_end2 = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][1][1])
+ len(assistant_prefix_suffix[i][1][1])
- 1,
)
# assert 1 in first assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
# assert 1 second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
# check not batched
output = tokenizer_r.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
chat_string = tokenizer_r.apply_chat_template(
conversations[0], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
assistant_end = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
)
assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
assistant_end2 = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
)
# assert 1 in assistant indices
self.assertEqual(
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"