Separate chat templates into a single file (#33957)

* Initial draft

* Add .jinja file loading for processors

* Add processor saving of naked chat template files

* make fixup

* Add save-load test for tokenizers

* Add save-load test for tokenizers

* stash commit

* Try popping the file

* make fixup

* Pop the arg correctly

* Pop the arg correctly

* Add processor test

* Fix processor code

* stash commit

* Processor clobbers child tokenizer's chat template

* Processor clobbers child tokenizer's chat template

* make fixup

* Split processor/tokenizer files to avoid interactions

* fix test

* Expand processor tests

* Rename arg to "save_raw_chat_template" across all classes

* Update processor warning

* Move templates to single file

* Move templates to single file

* Improve testing for processor/tokenizer clashes

* Improve testing for processor/tokenizer clashes

* Extend saving test

* Test file priority correctly

* make fixup

* Don't pop the chat template file before the slow tokenizer gets a look

* Remove breakpoint

* make fixup

* Fix error
This commit is contained in:
Matt
2024-11-26 14:18:04 +00:00
committed by GitHub
parent 5a45617887
commit d5cf91b346
4 changed files with 135 additions and 27 deletions

View File

@@ -44,7 +44,6 @@ from .tokenization_utils_base import (
TruncationStrategy, TruncationStrategy,
) )
from .utils import ( from .utils import (
CHAT_TEMPLATE_NAME,
PROCESSOR_NAME, PROCESSOR_NAME,
PushToHubMixin, PushToHubMixin,
TensorType, TensorType,
@@ -527,18 +526,24 @@ class ProcessorMixin(PushToHubMixin):
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
# plus we save chat_template in its own file # plus we save chat_template in its own file
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME) output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja")
output_chat_template_file = os.path.join(save_directory, "chat_template.json")
processor_dict = self.to_dict() processor_dict = self.to_dict()
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
# to avoid serializing chat template in json config file. So let's get it from `self` directly # to avoid serializing chat template in json config file. So let's get it from `self` directly
if self.chat_template is not None: if self.chat_template is not None:
chat_template_json_string = ( if kwargs.get("save_raw_chat_template", False):
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer:
) writer.write(self.chat_template)
with open(output_chat_template_file, "w", encoding="utf-8") as writer: logger.info(f"chat template saved in {output_raw_chat_template_file}")
writer.write(chat_template_json_string) else:
logger.info(f"chat template saved in {output_chat_template_file}") chat_template_json_string = (
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
)
with open(output_chat_template_file, "w", encoding="utf-8") as writer:
writer.write(chat_template_json_string)
logger.info(f"chat template saved in {output_chat_template_file}")
# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
# `auto_map` is not specified. # `auto_map` is not specified.
@@ -601,21 +606,23 @@ class ProcessorMixin(PushToHubMixin):
is_local = os.path.isdir(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME) processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json")
if os.path.isfile(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
resolved_processor_file = pretrained_model_name_or_path resolved_processor_file = pretrained_model_name_or_path
# cant't load chat-template when given a file as pretrained_model_name_or_path # cant't load chat-template when given a file as pretrained_model_name_or_path
resolved_chat_template_file = None resolved_chat_template_file = None
resolved_raw_chat_template_file = None
is_local = True is_local = True
elif is_remote_url(pretrained_model_name_or_path): elif is_remote_url(pretrained_model_name_or_path):
processor_file = pretrained_model_name_or_path processor_file = pretrained_model_name_or_path
resolved_processor_file = download_url(pretrained_model_name_or_path) resolved_processor_file = download_url(pretrained_model_name_or_path)
# can't load chat-template when given a file url as pretrained_model_name_or_path # can't load chat-template when given a file url as pretrained_model_name_or_path
resolved_chat_template_file = None resolved_chat_template_file = None
resolved_raw_chat_template_file = None
else: else:
processor_file = PROCESSOR_NAME processor_file = PROCESSOR_NAME
chat_template_file = CHAT_TEMPLATE_NAME chat_template_file = "chat_template.json"
raw_chat_template_file = "chat_template.jinja"
try: try:
# Load from local folder or from cache or download from model Hub and cache # Load from local folder or from cache or download from model Hub and cache
resolved_processor_file = cached_file( resolved_processor_file = cached_file(
@@ -650,6 +657,21 @@ class ProcessorMixin(PushToHubMixin):
subfolder=subfolder, subfolder=subfolder,
_raise_exceptions_for_missing_entries=False, _raise_exceptions_for_missing_entries=False,
) )
resolved_raw_chat_template_file = cached_file(
pretrained_model_name_or_path,
raw_chat_template_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
except EnvironmentError: except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception. # the original exception.
@@ -664,8 +686,11 @@ class ProcessorMixin(PushToHubMixin):
) )
# Add chat template as kwarg before returning because most models don't have processor config # Add chat template as kwarg before returning because most models don't have processor config
chat_template = None if resolved_raw_chat_template_file is not None:
if resolved_chat_template_file is not None: with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
chat_template = reader.read()
kwargs["chat_template"] = chat_template
elif resolved_chat_template_file is not None:
with open(resolved_chat_template_file, "r", encoding="utf-8") as reader: with open(resolved_chat_template_file, "r", encoding="utf-8") as reader:
text = reader.read() text = reader.read()
chat_template = json.loads(text)["chat_template"] chat_template = json.loads(text)["chat_template"]
@@ -696,7 +721,7 @@ class ProcessorMixin(PushToHubMixin):
if "chat_template" in processor_dict and processor_dict["chat_template"] is not None: if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
logger.warning_once( logger.warning_once(
"Chat templates should be in a 'chat_template.json' file but found key='chat_template' " "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
"in the processor's config. Make sure to move your template to its own file." "in the processor's config. Make sure to move your template to its own file."
) )

View File

@@ -145,6 +145,7 @@ AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json" ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json" TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
CHAT_TEMPLATE_FILE = "chat_template.jinja"
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json" FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -1941,6 +1942,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"tokenizer_config_file": TOKENIZER_CONFIG_FILE, "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
"tokenizer_file": FULL_TOKENIZER_FILE, "tokenizer_file": FULL_TOKENIZER_FILE,
"chat_template_file": CHAT_TEMPLATE_FILE,
} }
vocab_files = {**cls.vocab_files_names, **additional_files_names} vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files: if "tokenizer_file" in vocab_files:
@@ -2097,6 +2099,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
config_tokenizer_class = None config_tokenizer_class = None
init_kwargs = init_configuration init_kwargs = init_configuration
# If an independent chat template file exists, it takes priority over template entries in the tokenizer config
chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
if chat_template_file is not None:
with open(chat_template_file) as chat_template_handle:
init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config
if not _is_local: if not _is_local:
if "auto_map" in init_kwargs: if "auto_map" in init_kwargs:
# For backward compatibility with odl format. # For backward compatibility with odl format.
@@ -2396,6 +2404,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
tokenizer_config_file = os.path.join( tokenizer_config_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
) )
chat_template_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
)
tokenizer_config = copy.deepcopy(self.init_kwargs) tokenizer_config = copy.deepcopy(self.init_kwargs)
@@ -2418,7 +2429,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if isinstance(self.chat_template, dict): if isinstance(self.chat_template, dict):
# Chat template dicts are saved to the config as lists of dicts with fixed key names. # Chat template dicts are saved to the config as lists of dicts with fixed key names.
# They will be reconstructed as a single dict during loading. # They will be reconstructed as a single dict during loading.
# We're trying to discourage chat template dicts, and they are always
# saved in the config, never as single files.
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
elif kwargs.get("save_raw_chat_template", False):
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template)
logger.info(f"chat template saved in {chat_template_file}")
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
else: else:
tokenizer_config["chat_template"] = self.chat_template tokenizer_config["chat_template"] = self.chat_template

View File

@@ -18,6 +18,7 @@ import inspect
import json import json
import random import random
import tempfile import tempfile
from pathlib import Path
from typing import Optional from typing import Optional
import numpy as np import numpy as np
@@ -519,3 +520,27 @@ class ProcessorTesterMixin:
processor.prepare_and_validate_optional_call_args( processor.prepare_and_validate_optional_call_args(
*(f"optional_{i}" for i in range(num_optional_call_args + 1)) *(f"optional_{i}" for i in range(num_optional_call_args + 1))
) )
def test_chat_template_save_loading(self):
processor = self.get_processor()
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
processor.chat_template = "test template"
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we don't use single-file chat template saving, processor and tokenizer chat templates
# should remain separate
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname, save_raw_chat_template=True)
self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we save as single files, tokenizers and processors share a chat template, which means
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)

View File

@@ -25,6 +25,7 @@ import traceback
import unittest import unittest
from collections import OrderedDict from collections import OrderedDict
from itertools import takewhile from itertools import takewhile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from parameterized import parameterized from parameterized import parameterized
@@ -1107,13 +1108,29 @@ class TokenizerTesterMixin:
with tempfile.TemporaryDirectory() as tmp_dir_name: with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name) tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False) output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading self.assertEqual(output, expected_output) # Test output is the same after reloading
# Check that no error raised # Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
self.assertTrue(chat_template_file.is_file())
self.assertEqual(chat_template_file.read_text(), dummy_template)
config_dict = json.loads((Path(tmp_dir_name) / "tokenizer_config.json").read_text())
# Assert the chat template is not in the config when it's saved as a separate file
self.assertNotIn("chat_template", config_dict)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
# Check that no error raised
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
@require_jinja @require_jinja
def test_chat_template_batched(self): def test_chat_template_batched(self):
@@ -1526,18 +1543,40 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers() tokenizers = self.get_tokenizers()
for tokenizer in tokenizers: for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"): with self.subTest(f"{tokenizer.__class__.__name__}"):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} for save_raw_chat_template in (True, False):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "template1", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
@require_jinja
def test_chat_template_file_priority(self):
dummy_template1 = "a"
dummy_template2 = "b"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
with tempfile.TemporaryDirectory() as tmp_dir_name: with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name) tokenizer.chat_template = dummy_template1
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False)
# Assert that chat templates are correctly serialized as lists of dictionaries with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
self.assertEqual( f.write(dummy_template2)
config_dict["chat_template"],
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict # Assert the file template clobbers any template in the config
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template) self.assertEqual(new_tokenizer.chat_template, dummy_template2)
def test_number_of_added_tokens(self): def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False) tokenizers = self.get_tokenizers(do_lower_case=False)