Add correct batched handling for apply_chat_template (#29222)
* Add correct batched handling for apply_chat_template
* Fix warning method
* Add error for incompatible options
* expand tests
* Add a skip for markuplm
* Add skips for other layout models
* Skip for LayoutLMv2
* Slightly update the warning message
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* typo fix
* Update docstring for conversation kwarg
* Update return docstring
* Remove the warning, improve error message
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/tokenization_utils_base.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/test_tokenization_common.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/test_tokenization_common.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Remove return_dict=None
* Fix up some merge cruft
* More merge cruft
* Add another skip
* Add another skip

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1692,7 +1692,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
|
|
||||||
def apply_chat_template(
|
def apply_chat_template(
|
||||||
self,
|
self,
|
||||||
conversation: Union[List[Dict[str, str]], "Conversation"],
|
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
|
||||||
chat_template: Optional[str] = None,
|
chat_template: Optional[str] = None,
|
||||||
add_generation_prompt: bool = False,
|
add_generation_prompt: bool = False,
|
||||||
tokenize: bool = True,
|
tokenize: bool = True,
|
||||||
@@ -1703,15 +1703,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
return_dict: bool = False,
|
return_dict: bool = False,
|
||||||
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, List[int]]:
|
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
|
||||||
"""
|
"""
|
||||||
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
|
Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
|
||||||
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
|
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
|
||||||
determine the format and control tokens to use when converting. When chat_template is None, it will fall back
|
determine the format and control tokens to use when converting. When chat_template is None, it will fall back
|
||||||
to the default_chat_template specified at the class level.
|
to the default_chat_template specified at the class level.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
|
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
|
||||||
with "role" and "content" keys, representing the chat history so far.
|
with "role" and "content" keys, representing the chat history so far.
|
||||||
chat_template (str, *optional*): A Jinja template to use for this conversion. If
|
chat_template (str, *optional*): A Jinja template to use for this conversion. If
|
||||||
this is not passed, the model's default chat template will be used instead.
|
this is not passed, the model's default chat template will be used instead.
|
||||||
@@ -1735,19 +1735,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||||
return_dict (`bool`, *optional*, defaults to `False`):
|
return_dict (`bool`, defaults to `False`):
|
||||||
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
||||||
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
|
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
|
||||||
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
|
`Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
|
||||||
output is ready to pass to the model, either directly or via methods like `generate()`.
|
output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
|
||||||
|
set, will return a dict of tokenizer outputs instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if hasattr(conversation, "messages"):
|
if return_dict and not tokenize:
|
||||||
# Indicates it's a Conversation object
|
raise ValueError(
|
||||||
conversation = conversation.messages
|
"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
|
||||||
|
"of tokenizer outputs to return."
|
||||||
|
)
|
||||||
|
|
||||||
if tokenizer_kwargs is None:
|
if tokenizer_kwargs is None:
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
@@ -1779,34 +1782,43 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# Compilation function uses a cache to avoid recompiling the same template
|
# Compilation function uses a cache to avoid recompiling the same template
|
||||||
compiled_template = self._compile_jinja_template(chat_template)
|
compiled_template = self._compile_jinja_template(chat_template)
|
||||||
|
|
||||||
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
if isinstance(conversation, (list, tuple)) and (
|
||||||
rendered = compiled_template.render(
|
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
|
||||||
messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs
|
):
|
||||||
)
|
conversations = conversation
|
||||||
|
is_batched = True
|
||||||
|
else:
|
||||||
|
conversations = [conversation]
|
||||||
|
is_batched = False
|
||||||
|
|
||||||
|
rendered = []
|
||||||
|
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
||||||
|
for chat in conversations:
|
||||||
|
if hasattr(chat, "messages"):
|
||||||
|
# Indicates it's a Conversation object
|
||||||
|
chat = chat.messages
|
||||||
|
rendered_chat = compiled_template.render(
|
||||||
|
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
|
||||||
|
)
|
||||||
|
rendered.append(rendered_chat)
|
||||||
|
|
||||||
|
if not is_batched:
|
||||||
|
rendered = rendered[0]
|
||||||
|
|
||||||
if padding is True:
|
|
||||||
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
|
|
||||||
if tokenize:
|
if tokenize:
|
||||||
|
out = self(
|
||||||
|
rendered,
|
||||||
|
padding=padding,
|
||||||
|
truncation=truncation,
|
||||||
|
max_length=max_length,
|
||||||
|
add_special_tokens=False,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
if return_dict:
|
if return_dict:
|
||||||
return self(
|
return out
|
||||||
rendered,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
max_length=max_length,
|
|
||||||
add_special_tokens=False,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return self.encode(
|
return out["input_ids"]
|
||||||
rendered,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
max_length=max_length,
|
|
||||||
add_special_tokens=False,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return rendered
|
return rendered
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
|
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_wordpiece_tokenizer(self):
|
def test_wordpiece_tokenizer(self):
|
||||||
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
|
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
|
||||||
|
|
||||||
|
|||||||
@@ -140,6 +140,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
output_text = "lower newer"
|
output_text = "lower newer"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||||
text = "lower newer"
|
text = "lower newer"
|
||||||
|
|||||||
@@ -107,6 +107,10 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
output_text = "unwanted, running"
|
output_text = "unwanted, running"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
# override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of
|
# override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of
|
||||||
# this tokenizer
|
# this tokenizer
|
||||||
def test_save_sentencepiece_tokenizer(self) -> None:
|
def test_save_sentencepiece_tokenizer(self) -> None:
|
||||||
|
|||||||
@@ -101,6 +101,10 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
return questions, nodes, xpaths
|
return questions, nodes, xpaths
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def get_input_output_texts(self, tokenizer):
|
def get_input_output_texts(self, tokenizer):
|
||||||
input_text = "UNwant\u00E9d,running"
|
input_text = "UNwant\u00E9d,running"
|
||||||
output_text = "unwanted, running"
|
output_text = "unwanted, running"
|
||||||
|
|||||||
@@ -223,6 +223,10 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
rust_ids = rust_tokenizer.encode(sequence)
|
rust_ids = rust_tokenizer.encode(sequence)
|
||||||
self.assertListEqual(ids, rust_ids)
|
self.assertListEqual(ids, rust_ids)
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_chinese(self):
|
def test_chinese(self):
|
||||||
tokenizer = BasicTokenizer()
|
tokenizer = BasicTokenizer()
|
||||||
|
|
||||||
|
|||||||
@@ -1153,6 +1153,14 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# Assert there is online added_tokens special_tokens
|
# Assert there is online added_tokens special_tokens
|
||||||
self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
|
self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Chat template tests don't play well with table/layout models.")
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_encode_plus_sent_to_model(self):
|
def test_torch_encode_plus_sent_to_model(self):
|
||||||
|
|||||||
@@ -1104,26 +1104,73 @@ class TokenizerTesterMixin:
|
|||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
output = tokenizer.apply_chat_template(
|
output = tokenizer.apply_chat_template(
|
||||||
dummy_conversation, chat_template=dummy_template, tokenize=False
|
dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
|
||||||
)
|
)
|
||||||
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
|
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
|
||||||
|
|
||||||
# Check that no error raised when tokenize=True
|
# Check that no error raised when tokenize=True
|
||||||
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
|
output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
|
||||||
|
)
|
||||||
|
dict_output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
|
||||||
|
)
|
||||||
|
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
|
||||||
|
|
||||||
tokenizer.chat_template = dummy_template
|
tokenizer.chat_template = dummy_template
|
||||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
|
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
|
||||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
|
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||||
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
|
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
|
||||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
# Check that no error raised
|
||||||
|
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
tokenizer.save_pretrained(tmp_dir_name)
|
tokenizer.save_pretrained(tmp_dir_name)
|
||||||
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||||
|
|
||||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
|
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
|
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
# Check that no error raised
|
||||||
|
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_chat_template_batched(self):
|
||||||
|
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
|
||||||
|
dummy_conversations = [
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system message"},
|
||||||
|
{"role": "user", "content": "user message"},
|
||||||
|
{"role": "assistant", "content": "assistant message"},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system message 2"},
|
||||||
|
{"role": "user", "content": "user message 2"},
|
||||||
|
{"role": "assistant", "content": "assistant message 2"},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversations, chat_template=dummy_template, tokenize=False
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
[
|
||||||
|
"systemsystem messageuseruser messageassistantassistant message",
|
||||||
|
"systemsystem message 2useruser message 2assistantassistant message 2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
one_element_output = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversations[:1], chat_template=dummy_template, tokenize=False
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
|
||||||
|
) # Assert that list structure is retained even with one element
|
||||||
|
tokenizer.apply_chat_template(
|
||||||
|
dummy_conversations, chat_template=dummy_template, tokenize=True
|
||||||
|
) # Check that no error raised
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_chat_template_dict(self):
|
def test_chat_template_dict(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user