* Fix assistant tokens mask when the input is truncated * fix test * fix test * step
This commit is contained in:
@@ -1722,7 +1722,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
if start_token is None:
|
if start_token is None:
|
||||||
# start_token is out of bounds, possibly due to truncation.
|
# start_token is out of bounds, possibly due to truncation.
|
||||||
break
|
break
|
||||||
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids)):
|
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
|
||||||
current_mask[token_id] = 1
|
current_mask[token_id] = 1
|
||||||
assistant_masks.append(current_mask)
|
assistant_masks.append(current_mask)
|
||||||
out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
|
out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
|
||||||
|
|||||||
@@ -2497,3 +2497,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Chat is not supported")
|
@unittest.skip("Chat is not supported")
|
||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat is not supported")
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -2450,3 +2450,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Chat is not supported")
|
@unittest.skip("Chat is not supported")
|
||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat is not supported")
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1991,3 +1991,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Chat is not supported")
|
@unittest.skip("Chat is not supported")
|
||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat is not supported")
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -2330,3 +2330,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Chat is not supported")
|
@unittest.skip("Chat is not supported")
|
||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat is not supported")
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1290,3 +1290,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Chat is not supported")
|
@unittest.skip("Chat is not supported")
|
||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat is not supported")
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1161,6 +1161,10 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_chat_template_return_assistant_tokens_mask(self):
|
def test_chat_template_return_assistant_tokens_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat is not supported")
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
|
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
|
||||||
def test_chat_template_batched(self):
|
def test_chat_template_batched(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1327,6 +1327,110 @@ class TokenizerTesterMixin:
|
|||||||
[0] * (assistant_start2 - assistant_end - 1),
|
[0] * (assistant_start2 - assistant_end - 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||||||
|
dummy_template = (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"{% if (message['role'] != 'assistant') %}"
|
||||||
|
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
||||||
|
"{% elif (message['role'] == 'assistant')%}"
|
||||||
|
"{{'<|im_start|>' + message['role'] + '\n'}}"
|
||||||
|
"{% generation %}"
|
||||||
|
"{{message['content'] + '<|im_end|>'}}"
|
||||||
|
"{% endgeneration %}"
|
||||||
|
"{{'\n'}}"
|
||||||
|
"{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
)
|
||||||
|
conversations = [
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system message"},
|
||||||
|
{"role": "user", "content": "user message"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": (
|
||||||
|
"start turn assistant. long string to be truncated, long string to be truncated, "
|
||||||
|
"long string to be truncated, long string to be truncated, long string to be truncated"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "another user message"},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system message"},
|
||||||
|
{"role": "user", "content": "user message"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": (
|
||||||
|
"start turn assistant. long string to be truncated, long string to be truncated, "
|
||||||
|
"long string to be truncated, long string to be truncated, long string to be truncated"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "another user message"},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
for tokenizer, pretrained_name, _ in self.tokenizers_list:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||||
|
if not self.test_rust_tokenizer:
|
||||||
|
self.skipTest(reason="No fast tokenizer defined")
|
||||||
|
|
||||||
|
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
|
||||||
|
|
||||||
|
# Find where to truncate, as the number of tokens differs between tokenizers and we want the
|
||||||
|
# truncation to happen in the middle of the assistant content.
|
||||||
|
full_encoding = tokenizer_r.apply_chat_template(
|
||||||
|
conversations[0],
|
||||||
|
chat_template=dummy_template,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
chat_string = tokenizer_r.apply_chat_template(
|
||||||
|
conversations[0], tokenize=False, chat_template=dummy_template
|
||||||
|
)
|
||||||
|
truncation_position = full_encoding.char_to_token(chat_string.index(", long string to be truncated,"))
|
||||||
|
|
||||||
|
# check batched
|
||||||
|
output = tokenizer_r.apply_chat_template(
|
||||||
|
conversations,
|
||||||
|
chat_template=dummy_template,
|
||||||
|
tokenize=True,
|
||||||
|
return_assistant_tokens_mask=True,
|
||||||
|
max_length=truncation_position,
|
||||||
|
truncation=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
for i, conv in enumerate(conversations):
|
||||||
|
chat_string = tokenizer_r.apply_chat_template(conv, tokenize=False, chat_template=dummy_template)
|
||||||
|
assistant_start = output.char_to_token(i, chat_string.index("start turn assistant"))
|
||||||
|
|
||||||
|
# The mask must be 1 from assistant_start to the end, because everything after the assistant content is truncated.
|
||||||
|
self.assertEqual(
|
||||||
|
output["assistant_masks"][i][assistant_start:],
|
||||||
|
[1] * (len(output["assistant_masks"][i]) - assistant_start),
|
||||||
|
)
|
||||||
|
|
||||||
|
# check not batched
|
||||||
|
output = tokenizer_r.apply_chat_template(
|
||||||
|
conversations[0],
|
||||||
|
chat_template=dummy_template,
|
||||||
|
tokenize=True,
|
||||||
|
return_assistant_tokens_mask=True,
|
||||||
|
return_dict=True,
|
||||||
|
max_length=truncation_position,
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_string = tokenizer_r.apply_chat_template(
|
||||||
|
conversations[0], tokenize=False, chat_template=dummy_template
|
||||||
|
)
|
||||||
|
assistant_start = output.char_to_token(0, chat_string.index("start turn assistant"))
|
||||||
|
|
||||||
|
# The mask must be 1 from assistant_start to the end, because everything after the assistant content is truncated.
|
||||||
|
self.assertEqual(
|
||||||
|
output["assistant_masks"][assistant_start:],
|
||||||
|
[1] * (len(output["assistant_masks"]) - assistant_start),
|
||||||
|
)
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_continue_final_message(self):
|
def test_continue_final_message(self):
|
||||||
dummy_template = """
|
dummy_template = """
|
||||||
|
|||||||
Reference in New Issue
Block a user