Fix #34494 assistant tokens when truncated (#34531)

* Fix assistant tokens when truncated

* fix test

* fix test

* step
This commit is contained in:
Yoni Gottesman
2024-11-05 17:10:15 +02:00
committed by GitHub
parent 74d3824cc0
commit 082e57e0d4
8 changed files with 129 additions and 1 deletions

View File

@@ -1327,6 +1327,110 @@ class TokenizerTesterMixin:
[0] * (assistant_start2 - assistant_end - 1),
)
@require_jinja
def test_chat_template_return_assistant_tokens_mask_truncated(self):
    """Check the assistant tokens mask when truncation cuts the assistant turn short.

    Two identical conversations are rendered with a template that wraps the
    assistant content in a generation block. The encoding is truncated in the
    middle of the assistant content, and the test asserts that every mask entry
    from the first assistant token to the end of the truncated sequence is 1,
    for both the batched and the single-conversation call.
    """
    dummy_template = (
        "{% for message in messages %}"
        "{% if (message['role'] != 'assistant') %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% elif (message['role'] == 'assistant')%}"
        "{{'<|im_start|>' + message['role'] + '\n'}}"
        "{% generation %}"
        "{{message['content'] + '<|im_end|>'}}"
        "{% endgeneration %}"
        "{{'\n'}}"
        "{% endif %}"
        "{% endfor %}"
    )
    assistant_content = (
        "start turn assistant. long string to be truncated, long string to be truncated, "
        "long string to be truncated, long string to be truncated, long string to be truncated"
    )

    def make_conversation():
        # Return fresh objects on every call so the two batch entries do not share state.
        return [
            {"role": "system", "content": "system message"},
            {"role": "user", "content": "user message"},
            {"role": "assistant", "content": assistant_content},
            {"role": "user", "content": "another user message"},
        ]

    conversations = [make_conversation(), make_conversation()]
    # Keyword arguments for rendering a conversation to plain text (no tokenization).
    render_kwargs = {"tokenize": False, "chat_template": dummy_template}

    for tokenizer, pretrained_name, _ in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            if not self.test_rust_tokenizer:
                self.skipTest(reason="No fast tokenizer defined")

            fast_tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name)

            # The token count differs between tokenizers, so derive the truncation length from
            # the untruncated encoding: cut at the token covering a character offset that lies
            # inside the assistant content.
            untruncated = fast_tokenizer.apply_chat_template(
                conversations[0],
                chat_template=dummy_template,
                tokenize=True,
                return_dict=True,
            )
            rendered = fast_tokenizer.apply_chat_template(conversations[0], **render_kwargs)
            cut_char = rendered.index(", long string to be truncated,")
            truncation_position = untruncated.char_to_token(cut_char)

            truncate_kwargs = {
                "chat_template": dummy_template,
                "tokenize": True,
                "return_assistant_tokens_mask": True,
                "max_length": truncation_position,
                "truncation": True,
                "return_dict": True,
            }

            # Batched call.
            batched = fast_tokenizer.apply_chat_template(conversations, **truncate_kwargs)
            for row, conversation in enumerate(conversations):
                rendered = fast_tokenizer.apply_chat_template(conversation, **render_kwargs)
                first_assistant_token = batched.char_to_token(row, rendered.index("start turn assistant"))
                row_mask = batched["assistant_masks"][row]
                # Everything after the assistant start was truncated away, so the tail is all ones.
                self.assertEqual(
                    row_mask[first_assistant_token:],
                    [1] * (len(row_mask) - first_assistant_token),
                )

            # Single (non-batched) call.
            single = fast_tokenizer.apply_chat_template(conversations[0], **truncate_kwargs)
            rendered = fast_tokenizer.apply_chat_template(conversations[0], **render_kwargs)
            first_assistant_token = single.char_to_token(0, rendered.index("start turn assistant"))
            single_mask = single["assistant_masks"]
            # Everything after the assistant start was truncated away, so the tail is all ones.
            self.assertEqual(
                single_mask[first_assistant_token:],
                [1] * (len(single_mask) - first_assistant_token),
            )
@require_jinja
def test_continue_final_message(self):
dummy_template = """