Add correct batched handling for apply_chat_template (#29222)

* Add correct batched handling for apply_chat_template

* Fix warning method

* Add error for incompatible options

* expand tests

* Add a skip for markuplm

* Add skips for other layout models

* Skip for LayoutLMv2

* Slightly update the warning message

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* typo fix

* Update docstring for conversation kwarg

* Update return docstring

* Remove the warning, improve error message

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Remove return_dict=None

* Fix up some merge cruft

* More merge cruft

* Add another skip

* Add another skip

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Matt
2024-03-20 15:50:22 +00:00
committed by GitHub
parent 3c17c529cc
commit 9d999481b2
8 changed files with 127 additions and 40 deletions

View File

@@ -1104,26 +1104,73 @@ class TokenizerTesterMixin:
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False
dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
)
dict_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
)
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
@require_jinja
def test_chat_template_batched(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
],
[
{"role": "system", "content": "system message 2"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "assistant message 2"},
],
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=False
)
self.assertEqual(
output,
[
"systemsystem messageuseruser messageassistantassistant message",
"systemsystem message 2useruser message 2assistantassistant message 2",
],
)
one_element_output = tokenizer.apply_chat_template(
dummy_conversations[:1], chat_template=dummy_template, tokenize=False
)
self.assertEqual(
one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
) # Assert that list structure is retained even with one element
tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_chat_template_dict(self):