From 9d999481b2bb231de3c5980e407200bd0ce3ce4d Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 20 Mar 2024 15:50:22 +0000 Subject: [PATCH] Add correct batched handling for apply_chat_template (#29222) * Add correct batched handling for apply_chat_template * Fix warning method * Add error for incompatible options * expand tests * Add a skip for markuplm * Add skips for other layout models * Skip for LayoutLMv2 * Slightly update the warning message * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * typo fix * Update docstring for conversation kwarg * Update return docstring * Remove the warning, improve error message * Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Remove return_dict=None * Fix up some merge cruft * More merge cruft * Add another skip * Add another skip --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 80 +++++++++++-------- .../test_tokenization_layoutlmv2.py | 4 + .../test_tokenization_layoutlmv3.py | 4 + .../layoutxlm/test_tokenization_layoutxlm.py | 4 + .../markuplm/test_tokenization_markuplm.py | 4 + tests/models/tapas/test_tokenization_tapas.py | 4 + tests/models/udop/test_tokenization_udop.py | 8 ++ tests/test_tokenization_common.py | 59 ++++++++++++-- 8 files changed, 127 insertions(+), 40 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2ab70f2d53..e5cd6dce06 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1692,7 +1692,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def apply_chat_template( self, - conversation: Union[List[Dict[str, str]], "Conversation"], + conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"], chat_template: Optional[str] = None, add_generation_prompt: bool = False, tokenize: bool = True, @@ -1703,15 +1703,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): return_dict: bool = False, tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs, - ) -> Union[str, List[int]]: + ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]: """ - Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token + Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting. When chat_template is None, it will fall back to the default_chat_template specified at the class level. Args: - conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts + conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts with "role" and "content" keys, representing the chat history so far. chat_template (str, *optional*): A Jinja template to use for this conversion. If this is not passed, the model's default chat template will be used instead. @@ -1735,19 +1735,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. - return_dict (`bool`, *optional*, defaults to `False`): + return_dict (`bool`, defaults to `False`): Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. Returns: - `List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This - output is ready to pass to the model, either directly or via methods like `generate()`. + `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This + output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is + set, will return a dict of tokenizer outputs instead. """ - if hasattr(conversation, "messages"): - # Indicates it's a Conversation object - conversation = conversation.messages + if return_dict and not tokenize: + raise ValueError( + "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict " + "of tokenizer outputs to return." + ) if tokenizer_kwargs is None: tokenizer_kwargs = {} @@ -1779,34 +1782,43 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # Compilation function uses a cache to avoid recompiling the same template compiled_template = self._compile_jinja_template(chat_template) - template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present - rendered = compiled_template.render( - messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs - ) + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") + ): + conversations = conversation + is_batched = True + else: + conversations = [conversation] + is_batched = False + + rendered = [] + template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present + for chat in conversations: + if hasattr(chat, "messages"): + # Indicates it's a Conversation object + chat = chat.messages + rendered_chat = compiled_template.render( + messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs + ) + rendered.append(rendered_chat) + + if not is_batched: + rendered = rendered[0] - if padding is True: - padding = "max_length" # There's only one sequence here, so "longest" makes no sense if tokenize: + out = self( + rendered, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) if return_dict: - return self( - rendered, - padding=padding, - truncation=truncation, - max_length=max_length, - add_special_tokens=False, - return_tensors=return_tensors, - **tokenizer_kwargs, - ) + return out else: - return self.encode( - rendered, - padding=padding, - truncation=truncation, - max_length=max_length, - add_special_tokens=False, - return_tensors=return_tensors, - **tokenizer_kwargs, - ) + return out["input_ids"] else: return rendered diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py index 61fafa23da..ce6bbd0f01 100644 --- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py @@ -195,6 +195,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"] ) + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template_batched(self): + pass + def test_wordpiece_tokenizer(self): vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"] diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py index c7af4fbddc..80d29d3a46 100644 --- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -140,6 +140,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase): output_text = "lower newer" return input_text, output_text + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template_batched(self): + pass + def test_full_tokenizer(self): tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map) text = "lower newer" diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index 474cc53171..8f1d353efd 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -107,6 +107,10 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): output_text = "unwanted, running" return input_text, output_text + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template_batched(self): + pass + # override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of # this tokenizer def test_save_sentencepiece_tokenizer(self) -> None: diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py index df1b9ed083..370b1c5692 100644 --- a/tests/models/markuplm/test_tokenization_markuplm.py +++ b/tests/models/markuplm/test_tokenization_markuplm.py @@ -101,6 +101,10 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): return questions, nodes, xpaths + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template_batched(self): + pass + def get_input_output_texts(self, tokenizer): input_text = "UNwant\u00E9d,running" output_text = "unwanted, running" diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index 9aa3ad6d58..4100e02b13 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -223,6 +223,10 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase): rust_ids = rust_tokenizer.encode(sequence) self.assertListEqual(ids, rust_ids) + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template_batched(self): + pass + def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py index f9ad6c7abe..0519ee0622 100644 --- a/tests/models/udop/test_tokenization_udop.py +++ b/tests/models/udop/test_tokenization_udop.py @@ -1153,6 +1153,14 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase): # Assert there is online added_tokens special_tokens self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens) + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template(self): + pass + + @unittest.skip("Chat template tests don't play well with table/layout models.") + def test_chat_template_batched(self): + pass + @require_torch @slow def test_torch_encode_plus_sent_to_model(self): diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index fa1be251e0..8216db084c 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1104,26 +1104,73 @@ class TokenizerTesterMixin: for tokenizer in tokenizers: with self.subTest(f"{tokenizer.__class__.__name__}"): output = tokenizer.apply_chat_template( - dummy_conversation, chat_template=dummy_template, tokenize=False + dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False ) self.assertEqual(output, expected_output) # Test we can pass chat_template arg + # Check that no error raised when tokenize=True - tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True) + output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False + ) + dict_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True + ) + self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches tokenizer.chat_template = dummy_template self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter - output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False) + output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False) self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed - tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised + # Check that no error raised + tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) with tempfile.TemporaryDirectory() as tmp_dir_name: tokenizer.save_pretrained(tmp_dir_name) tokenizer = tokenizer.from_pretrained(tmp_dir_name) self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted - output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False) + output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False) self.assertEqual(output, expected_output) # Test output is the same after reloading - tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised + # Check that no error raised + tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) + + @require_jinja + def test_chat_template_batched(self): + dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}" + dummy_conversations = [ + [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message"}, + ], + [ + {"role": "system", "content": "system message 2"}, + {"role": "user", "content": "user message 2"}, + {"role": "assistant", "content": "assistant message 2"}, + ], + ] + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + output = tokenizer.apply_chat_template( + dummy_conversations, chat_template=dummy_template, tokenize=False + ) + self.assertEqual( + output, + [ + "systemsystem messageuseruser messageassistantassistant message", + "systemsystem message 2useruser message 2assistantassistant message 2", + ], + ) + one_element_output = tokenizer.apply_chat_template( + dummy_conversations[:1], chat_template=dummy_template, tokenize=False + ) + self.assertEqual( + one_element_output, ["systemsystem messageuseruser messageassistantassistant message"] + ) # Assert that list structure is retained even with one element + tokenizer.apply_chat_template( + dummy_conversations, chat_template=dummy_template, tokenize=True + ) # Check that no error raised @require_jinja def test_chat_template_dict(self):