* Fix assistant tokens when truncated * fix test * fix test * step
This commit is contained in:
@@ -1327,6 +1327,110 @@ class TokenizerTesterMixin:
|
||||
[0] * (assistant_start2 - assistant_end - 1),
|
||||
)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||
dummy_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if (message['role'] != 'assistant') %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
||||
"{% elif (message['role'] == 'assistant')%}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n'}}"
|
||||
"{% generation %}"
|
||||
"{{message['content'] + '<|im_end|>'}}"
|
||||
"{% endgeneration %}"
|
||||
"{{'\n'}}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
conversations = [
|
||||
[
|
||||
{"role": "system", "content": "system message"},
|
||||
{"role": "user", "content": "user message"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
"start turn assistant. long string to be truncated, long string to be truncated, "
|
||||
"long string to be truncated, long string to be truncated, long string to be truncated"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": "another user message"},
|
||||
],
|
||||
[
|
||||
{"role": "system", "content": "system message"},
|
||||
{"role": "user", "content": "user message"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
"start turn assistant. long string to be truncated, long string to be truncated, "
|
||||
"long string to be truncated, long string to be truncated, long string to be truncated"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": "another user message"},
|
||||
],
|
||||
]
|
||||
|
||||
for tokenizer, pretrained_name, _ in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
if not self.test_rust_tokenizer:
|
||||
self.skipTest(reason="No fast tokenizer defined")
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
|
||||
|
||||
# Find where to truncate, as the amount of tokens is different for different tokenizers and I want the
|
||||
# truncation to happen in the middle of the assistant content.
|
||||
full_encoding = tokenizer_r.apply_chat_template(
|
||||
conversations[0],
|
||||
chat_template=dummy_template,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
chat_string = tokenizer_r.apply_chat_template(
|
||||
conversations[0], tokenize=False, chat_template=dummy_template
|
||||
)
|
||||
truncation_position = full_encoding.char_to_token(chat_string.index(", long string to be truncated,"))
|
||||
|
||||
# check batched
|
||||
output = tokenizer_r.apply_chat_template(
|
||||
conversations,
|
||||
chat_template=dummy_template,
|
||||
tokenize=True,
|
||||
return_assistant_tokens_mask=True,
|
||||
max_length=truncation_position,
|
||||
truncation=True,
|
||||
return_dict=True,
|
||||
)
|
||||
for i, conv in enumerate(conversations):
|
||||
chat_string = tokenizer_r.apply_chat_template(conv, tokenize=False, chat_template=dummy_template)
|
||||
assistant_start = output.char_to_token(i, chat_string.index("start turn assistant"))
|
||||
|
||||
# assert 1 from assistant_start to the end because the rest is truncated.
|
||||
self.assertEqual(
|
||||
output["assistant_masks"][i][assistant_start:],
|
||||
[1] * (len(output["assistant_masks"][i]) - assistant_start),
|
||||
)
|
||||
|
||||
# check not batched
|
||||
output = tokenizer_r.apply_chat_template(
|
||||
conversations[0],
|
||||
chat_template=dummy_template,
|
||||
tokenize=True,
|
||||
return_assistant_tokens_mask=True,
|
||||
return_dict=True,
|
||||
max_length=truncation_position,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
chat_string = tokenizer_r.apply_chat_template(
|
||||
conversations[0], tokenize=False, chat_template=dummy_template
|
||||
)
|
||||
assistant_start = output.char_to_token(0, chat_string.index("start turn assistant"))
|
||||
|
||||
# assert 1 from assistant_start to the end because the rest is truncated.
|
||||
self.assertEqual(
|
||||
output["assistant_masks"][assistant_start:],
|
||||
[1] * (len(output["assistant_masks"]) - assistant_start),
|
||||
)
|
||||
|
||||
@require_jinja
|
||||
def test_continue_final_message(self):
|
||||
dummy_template = """
|
||||
|
||||
Reference in New Issue
Block a user