Add correct batched handling for apply_chat_template (#29222)
* Add correct batched handling for apply_chat_template * Fix warning method * Add error for incompatible options * expand tests * Add a skip for markuplm * Add skips for other layout models * Skip for LayoutLMv2 * Slightly update the warning message * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * typo fix * Update docstring for conversation kwarg * Update return docstring * Remove the warning, improve error message * Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/test_tokenization_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Remove return_dict=None * Fix up some merge cruft * More merge cruft * Add another skip * Add another skip --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1104,26 +1104,73 @@ class TokenizerTesterMixin:
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
output = tokenizer.apply_chat_template(
|
||||
dummy_conversation, chat_template=dummy_template, tokenize=False
|
||||
dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
|
||||
)
|
||||
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
|
||||
|
||||
# Check that no error raised when tokenize=True
|
||||
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
|
||||
output = tokenizer.apply_chat_template(
|
||||
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
|
||||
)
|
||||
dict_output = tokenizer.apply_chat_template(
|
||||
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
|
||||
)
|
||||
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
|
||||
|
||||
tokenizer.chat_template = dummy_template
|
||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
|
||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
|
||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
||||
# Check that no error raised
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name)
|
||||
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
|
||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
|
||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
||||
# Check that no error raised
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_batched(self):
|
||||
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
|
||||
dummy_conversations = [
|
||||
[
|
||||
{"role": "system", "content": "system message"},
|
||||
{"role": "user", "content": "user message"},
|
||||
{"role": "assistant", "content": "assistant message"},
|
||||
],
|
||||
[
|
||||
{"role": "system", "content": "system message 2"},
|
||||
{"role": "user", "content": "user message 2"},
|
||||
{"role": "assistant", "content": "assistant message 2"},
|
||||
],
|
||||
]
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
output = tokenizer.apply_chat_template(
|
||||
dummy_conversations, chat_template=dummy_template, tokenize=False
|
||||
)
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
"systemsystem messageuseruser messageassistantassistant message",
|
||||
"systemsystem message 2useruser message 2assistantassistant message 2",
|
||||
],
|
||||
)
|
||||
one_element_output = tokenizer.apply_chat_template(
|
||||
dummy_conversations[:1], chat_template=dummy_template, tokenize=False
|
||||
)
|
||||
self.assertEqual(
|
||||
one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
|
||||
) # Assert that list structure is retained even with one element
|
||||
tokenizer.apply_chat_template(
|
||||
dummy_conversations, chat_template=dummy_template, tokenize=True
|
||||
) # Check that no error raised
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_dict(self):
|
||||
|
||||
Reference in New Issue
Block a user