Separate chat templates into a single file (#33957)
* Initial draft * Add .jinja file loading for processors * Add processor saving of naked chat template files * make fixup * Add save-load test for tokenizers * Add save-load test for tokenizers * stash commit * Try popping the file * make fixup * Pop the arg correctly * Pop the arg correctly * Add processor test * Fix processor code * stash commit * Processor clobbers child tokenizer's chat template * Processor clobbers child tokenizer's chat template * make fixup * Split processor/tokenizer files to avoid interactions * fix test * Expand processor tests * Rename arg to "save_raw_chat_template" across all classes * Update processor warning * Move templates to single file * Move templates to single file * Improve testing for processor/tokenizer clashes * Improve testing for processor/tokenizer clashes * Extend saving test * Test file priority correctly * make fixup * Don't pop the chat template file before the slow tokenizer gets a look * Remove breakpoint * make fixup * Fix error
This commit is contained in:
@@ -25,6 +25,7 @@ import traceback
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from itertools import takewhile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
from parameterized import parameterized
|
||||
@@ -1107,13 +1108,29 @@ class TokenizerTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name)
|
||||
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
|
||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||
# Check that no error raised
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
|
||||
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
|
||||
self.assertTrue(chat_template_file.is_file())
|
||||
self.assertEqual(chat_template_file.read_text(), dummy_template)
|
||||
config_dict = json.loads((Path(tmp_dir_name) / "tokenizer_config.json").read_text())
|
||||
# Assert the chat template is not in the config when it's saved as a separate file
|
||||
self.assertNotIn("chat_template", config_dict)
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||
# Check that no error raised
|
||||
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_batched(self):
|
||||
@@ -1526,18 +1543,40 @@ class TokenizerTesterMixin:
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||
for save_raw_chat_template in (True, False):
|
||||
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
|
||||
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
||||
# Assert that chat templates are correctly serialized as lists of dictionaries
|
||||
self.assertEqual(
|
||||
config_dict["chat_template"],
|
||||
[
|
||||
{"name": "template1", "template": "{{'a'}}"},
|
||||
{"name": "template2", "template": "{{'b'}}"},
|
||||
],
|
||||
)
|
||||
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
# Assert that the serialized list is correctly reconstructed as a single dict
|
||||
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_file_priority(self):
|
||||
dummy_template1 = "a"
|
||||
dummy_template2 = "b"
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name)
|
||||
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
||||
# Assert that chat templates are correctly serialized as lists of dictionaries
|
||||
self.assertEqual(
|
||||
config_dict["chat_template"],
|
||||
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
|
||||
)
|
||||
tokenizer.chat_template = dummy_template1
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False)
|
||||
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
|
||||
f.write(dummy_template2)
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
# Assert that the serialized list is correctly reconstructed as a single dict
|
||||
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||
# Assert the file template clobbers any template in the config
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template2)
|
||||
|
||||
def test_number_of_added_tokens(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
|
||||
Reference in New Issue
Block a user