Separate chat templates into a single file (#33957)
* Initial draft * Add .jinja file loading for processors * Add processor saving of naked chat template files * make fixup * Add save-load test for tokenizers * Add save-load test for tokenizers * stash commit * Try popping the file * make fixup * Pop the arg correctly * Pop the arg correctly * Add processor test * Fix processor code * stash commit * Processor clobbers child tokenizer's chat template * Processor clobbers child tokenizer's chat template * make fixup * Split processor/tokenizer files to avoid interactions * fix test * Expand processor tests * Rename arg to "save_raw_chat_template" across all classes * Update processor warning * Move templates to single file * Move templates to single file * Improve testing for processor/tokenizer clashes * Improve testing for processor/tokenizer clashes * Extend saving test * Test file priority correctly * make fixup * Don't pop the chat template file before the slow tokenizer gets a look * Remove breakpoint * make fixup * Fix error
This commit is contained in:
@@ -18,6 +18,7 @@ import inspect
|
||||
import json
|
||||
import random
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -519,3 +520,27 @@ class ProcessorTesterMixin:
|
||||
processor.prepare_and_validate_optional_call_args(
|
||||
*(f"optional_{i}" for i in range(num_optional_call_args + 1))
|
||||
)
|
||||
|
||||
def test_chat_template_save_loading(self):
|
||||
processor = self.get_processor()
|
||||
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
|
||||
processor.chat_template = "test template"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
processor.save_pretrained(tmpdirname)
|
||||
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
|
||||
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
|
||||
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
|
||||
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
|
||||
# When we don't use single-file chat template saving, processor and tokenizer chat templates
|
||||
# should remain separate
|
||||
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
processor.save_pretrained(tmpdirname, save_raw_chat_template=True)
|
||||
self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
|
||||
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
|
||||
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
|
||||
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
|
||||
# When we save as single files, tokenizers and processors share a chat template, which means
|
||||
# the reloaded tokenizer should get the chat template as well
|
||||
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
|
||||
|
||||
@@ -25,6 +25,7 @@ import traceback
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from itertools import takewhile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
from parameterized import parameterized
|
||||
@@ -1107,13 +1108,29 @@ class TokenizerTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name)
|
||||
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
|
||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||
# Check that no error raised
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
|
||||
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
|
||||
self.assertTrue(chat_template_file.is_file())
|
||||
self.assertEqual(chat_template_file.read_text(), dummy_template)
|
||||
config_dict = json.loads((Path(tmp_dir_name) / "tokenizer_config.json").read_text())
|
||||
# Assert the chat template is not in the config when it's saved as a separate file
|
||||
self.assertNotIn("chat_template", config_dict)
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||
# Check that no error raised
|
||||
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_batched(self):
|
||||
@@ -1526,18 +1543,40 @@ class TokenizerTesterMixin:
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||
for save_raw_chat_template in (True, False):
|
||||
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
|
||||
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
||||
# Assert that chat templates are correctly serialized as lists of dictionaries
|
||||
self.assertEqual(
|
||||
config_dict["chat_template"],
|
||||
[
|
||||
{"name": "template1", "template": "{{'a'}}"},
|
||||
{"name": "template2", "template": "{{'b'}}"},
|
||||
],
|
||||
)
|
||||
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
# Assert that the serialized list is correctly reconstructed as a single dict
|
||||
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||
|
||||
@require_jinja
|
||||
def test_chat_template_file_priority(self):
|
||||
dummy_template1 = "a"
|
||||
dummy_template2 = "b"
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
tokenizer.save_pretrained(tmp_dir_name)
|
||||
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
||||
# Assert that chat templates are correctly serialized as lists of dictionaries
|
||||
self.assertEqual(
|
||||
config_dict["chat_template"],
|
||||
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
|
||||
)
|
||||
tokenizer.chat_template = dummy_template1
|
||||
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False)
|
||||
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
|
||||
f.write(dummy_template2)
|
||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||
# Assert that the serialized list is correctly reconstructed as a single dict
|
||||
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||
# Assert the file template clobbers any template in the config
|
||||
self.assertEqual(new_tokenizer.chat_template, dummy_template2)
|
||||
|
||||
def test_number_of_added_tokens(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
|
||||
Reference in New Issue
Block a user