From 082e57e0d42a46a4a5244f8d005eeb88c8da37b6 Mon Sep 17 00:00:00 2001 From: Yoni Gottesman Date: Tue, 5 Nov 2024 17:10:15 +0200 Subject: [PATCH] Fix #34494 assistant tokens when truncated (#34531) * Fix assistant tokens when truncated * fix test * fix test * step --- src/transformers/tokenization_utils_base.py | 2 +- .../test_tokenization_layoutlmv2.py | 4 + .../test_tokenization_layoutlmv3.py | 4 + .../layoutxlm/test_tokenization_layoutxlm.py | 4 + .../markuplm/test_tokenization_markuplm.py | 4 + tests/models/tapas/test_tokenization_tapas.py | 4 + tests/models/udop/test_tokenization_udop.py | 4 + tests/test_tokenization_common.py | 104 ++++++++++++++++++ 8 files changed, 129 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 43e13abe56..381f3ef497 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1722,7 +1722,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): if start_token is None: # start_token is out of bounds maybe due to truncation. break - for token_id in range(start_token, end_token + 1 if end_token else len(input_ids)): + for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])): current_mask[token_id] = 1 assistant_masks.append(current_mask) out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0] diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py index 9e39cd0279..7dcf539970 100644 --- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py @@ -2497,3 +2497,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): @unittest.skip("Chat is not supported") def test_chat_template_return_assistant_tokens_mask(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template_return_assistant_tokens_mask_truncated(self): + pass diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py index 4a218d3f21..9af0861536 100644 --- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -2450,3 +2450,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase): @unittest.skip("Chat is not supported") def test_chat_template_return_assistant_tokens_mask(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template_return_assistant_tokens_mask_truncated(self): + pass diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index 9f6d65ffc5..f387e52790 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -1991,3 +1991,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @unittest.skip("Chat is not supported") def test_chat_template_return_assistant_tokens_mask(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template_return_assistant_tokens_mask_truncated(self): + pass diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py index 60c98776b2..eaf30131d3 100644 --- a/tests/models/markuplm/test_tokenization_markuplm.py +++ b/tests/models/markuplm/test_tokenization_markuplm.py @@ -2330,3 +2330,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @unittest.skip("Chat is not supported") def test_chat_template_return_assistant_tokens_mask(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template_return_assistant_tokens_mask_truncated(self): + pass diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index 49327a39cd..0a911f7182 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -1290,3 +1290,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @unittest.skip("Chat is not supported") def test_chat_template_return_assistant_tokens_mask(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template_return_assistant_tokens_mask_truncated(self): + pass diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py index 90d669064a..a6ac2ff3d3 100644 --- a/tests/models/udop/test_tokenization_udop.py +++ b/tests/models/udop/test_tokenization_udop.py @@ -1161,6 +1161,10 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_chat_template_return_assistant_tokens_mask(self): pass + @unittest.skip("Chat is not supported") + def test_chat_template_return_assistant_tokens_mask_truncated(self): + pass + @unittest.skip(reason="Chat template tests don't play well with table/layout models.") def test_chat_template_batched(self): pass diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index dd9eb10de4..a3bbbf3c9e 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1327,6 +1327,110 @@ class TokenizerTesterMixin: [0] * (assistant_start2 - assistant_end - 1), ) + @require_jinja + def test_chat_template_return_assistant_tokens_mask_truncated(self): + dummy_template = ( + "{% for message in messages %}" + "{% if (message['role'] != 'assistant') %}" + "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" + "{% elif (message['role'] == 'assistant')%}" + "{{'<|im_start|>' + message['role'] + '\n'}}" + "{% generation %}" + "{{message['content'] + '<|im_end|>'}}" + "{% endgeneration %}" + "{{'\n'}}" + "{% endif %}" + "{% endfor %}" + ) + conversations = [ + [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + { + "role": "assistant", + "content": ( + "start turn assistant. long string to be truncated, long string to be truncated, " + "long string to be truncated, long string to be truncated, long string to be truncated" + ), + }, + {"role": "user", "content": "another user message"}, + ], + [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + { + "role": "assistant", + "content": ( + "start turn assistant. long string to be truncated, long string to be truncated, " + "long string to be truncated, long string to be truncated, long string to be truncated" + ), + }, + {"role": "user", "content": "another user message"}, + ], + ] + + for tokenizer, pretrained_name, _ in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + if not self.test_rust_tokenizer: + self.skipTest(reason="No fast tokenizer defined") + + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name) + + # Find where to truncate, as the amount of tokens is different for different tokenizers and I want the + # truncation to happen in the middle of the assistant content. + full_encoding = tokenizer_r.apply_chat_template( + conversations[0], + chat_template=dummy_template, + tokenize=True, + return_dict=True, + ) + chat_string = tokenizer_r.apply_chat_template( + conversations[0], tokenize=False, chat_template=dummy_template + ) + truncation_position = full_encoding.char_to_token(chat_string.index(", long string to be truncated,")) + + # check batched + output = tokenizer_r.apply_chat_template( + conversations, + chat_template=dummy_template, + tokenize=True, + return_assistant_tokens_mask=True, + max_length=truncation_position, + truncation=True, + return_dict=True, + ) + for i, conv in enumerate(conversations): + chat_string = tokenizer_r.apply_chat_template(conv, tokenize=False, chat_template=dummy_template) + assistant_start = output.char_to_token(i, chat_string.index("start turn assistant")) + + # assert 1 from assistant_start to the end because the rest is truncated. + self.assertEqual( + output["assistant_masks"][i][assistant_start:], + [1] * (len(output["assistant_masks"][i]) - assistant_start), + ) + + # check not batched + output = tokenizer_r.apply_chat_template( + conversations[0], + chat_template=dummy_template, + tokenize=True, + return_assistant_tokens_mask=True, + return_dict=True, + max_length=truncation_position, + truncation=True, + ) + + chat_string = tokenizer_r.apply_chat_template( + conversations[0], tokenize=False, chat_template=dummy_template + ) + assistant_start = output.char_to_token(0, chat_string.index("start turn assistant")) + + # assert 1 from assistant_start to the end because the rest is truncated. + self.assertEqual( + output["assistant_masks"][assistant_start:], + [1] * (len(output["assistant_masks"]) - assistant_start), + ) + @require_jinja def test_continue_final_message(self): dummy_template = """