Separate chat templates into a single file (#33957)
* Initial draft * Add .jinja file loading for processors * Add processor saving of naked chat template files * make fixup * Add save-load test for tokenizers * Add save-load test for tokenizers * stash commit * Try popping the file * make fixup * Pop the arg correctly * Pop the arg correctly * Add processor test * Fix processor code * stash commit * Processor clobbers child tokenizer's chat template * Processor clobbers child tokenizer's chat template * make fixup * Split processor/tokenizer files to avoid interactions * fix test * Expand processor tests * Rename arg to "save_raw_chat_template" across all classes * Update processor warning * Move templates to single file * Move templates to single file * Improve testing for processor/tokenizer clashes * Improve testing for processor/tokenizer clashes * Extend saving test * Test file priority correctly * make fixup * Don't pop the chat template file before the slow tokenizer gets a look * Remove breakpoint * make fixup * Fix error
This commit is contained in:
@@ -44,7 +44,6 @@ from .tokenization_utils_base import (
|
|||||||
TruncationStrategy,
|
TruncationStrategy,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
CHAT_TEMPLATE_NAME,
|
|
||||||
PROCESSOR_NAME,
|
PROCESSOR_NAME,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
TensorType,
|
TensorType,
|
||||||
@@ -527,12 +526,18 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
# plus we save chat_template in its own file
|
# plus we save chat_template in its own file
|
||||||
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
|
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
|
||||||
output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)
|
output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja")
|
||||||
|
output_chat_template_file = os.path.join(save_directory, "chat_template.json")
|
||||||
|
|
||||||
processor_dict = self.to_dict()
|
processor_dict = self.to_dict()
|
||||||
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
|
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
|
||||||
# to avoid serializing chat template in json config file. So let's get it from `self` directly
|
# to avoid serializing chat template in json config file. So let's get it from `self` directly
|
||||||
if self.chat_template is not None:
|
if self.chat_template is not None:
|
||||||
|
if kwargs.get("save_raw_chat_template", False):
|
||||||
|
with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer:
|
||||||
|
writer.write(self.chat_template)
|
||||||
|
logger.info(f"chat template saved in {output_raw_chat_template_file}")
|
||||||
|
else:
|
||||||
chat_template_json_string = (
|
chat_template_json_string = (
|
||||||
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
|
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
|
||||||
)
|
)
|
||||||
@@ -601,21 +606,23 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
|
processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
|
||||||
chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json")
|
|
||||||
|
|
||||||
if os.path.isfile(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
resolved_processor_file = pretrained_model_name_or_path
|
resolved_processor_file = pretrained_model_name_or_path
|
||||||
# can't load chat-template when given a file as pretrained_model_name_or_path
|
# can't load chat-template when given a file as pretrained_model_name_or_path
|
||||||
resolved_chat_template_file = None
|
resolved_chat_template_file = None
|
||||||
|
resolved_raw_chat_template_file = None
|
||||||
is_local = True
|
is_local = True
|
||||||
elif is_remote_url(pretrained_model_name_or_path):
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
processor_file = pretrained_model_name_or_path
|
processor_file = pretrained_model_name_or_path
|
||||||
resolved_processor_file = download_url(pretrained_model_name_or_path)
|
resolved_processor_file = download_url(pretrained_model_name_or_path)
|
||||||
# can't load chat-template when given a file url as pretrained_model_name_or_path
|
# can't load chat-template when given a file url as pretrained_model_name_or_path
|
||||||
resolved_chat_template_file = None
|
resolved_chat_template_file = None
|
||||||
|
resolved_raw_chat_template_file = None
|
||||||
else:
|
else:
|
||||||
processor_file = PROCESSOR_NAME
|
processor_file = PROCESSOR_NAME
|
||||||
chat_template_file = CHAT_TEMPLATE_NAME
|
chat_template_file = "chat_template.json"
|
||||||
|
raw_chat_template_file = "chat_template.jinja"
|
||||||
try:
|
try:
|
||||||
# Load from local folder or from cache or download from model Hub and cache
|
# Load from local folder or from cache or download from model Hub and cache
|
||||||
resolved_processor_file = cached_file(
|
resolved_processor_file = cached_file(
|
||||||
@@ -650,6 +657,21 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
_raise_exceptions_for_missing_entries=False,
|
_raise_exceptions_for_missing_entries=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
resolved_raw_chat_template_file = cached_file(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
raw_chat_template_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
|
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
|
||||||
# the original exception.
|
# the original exception.
|
||||||
@@ -664,8 +686,11 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add chat template as kwarg before returning because most models don't have processor config
|
# Add chat template as kwarg before returning because most models don't have processor config
|
||||||
chat_template = None
|
if resolved_raw_chat_template_file is not None:
|
||||||
if resolved_chat_template_file is not None:
|
with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
|
||||||
|
chat_template = reader.read()
|
||||||
|
kwargs["chat_template"] = chat_template
|
||||||
|
elif resolved_chat_template_file is not None:
|
||||||
with open(resolved_chat_template_file, "r", encoding="utf-8") as reader:
|
with open(resolved_chat_template_file, "r", encoding="utf-8") as reader:
|
||||||
text = reader.read()
|
text = reader.read()
|
||||||
chat_template = json.loads(text)["chat_template"]
|
chat_template = json.loads(text)["chat_template"]
|
||||||
@@ -696,7 +721,7 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
|
|
||||||
if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
|
if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Chat templates should be in a 'chat_template.json' file but found key='chat_template' "
|
"Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
|
||||||
"in the processor's config. Make sure to move your template to its own file."
|
"in the processor's config. Make sure to move your template to its own file."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch
|
|||||||
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
|
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
|
||||||
ADDED_TOKENS_FILE = "added_tokens.json"
|
ADDED_TOKENS_FILE = "added_tokens.json"
|
||||||
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
||||||
|
CHAT_TEMPLATE_FILE = "chat_template.jinja"
|
||||||
|
|
||||||
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
|
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
|
||||||
FULL_TOKENIZER_FILE = "tokenizer.json"
|
FULL_TOKENIZER_FILE = "tokenizer.json"
|
||||||
@@ -1941,6 +1942,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||||
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
|
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
|
||||||
"tokenizer_file": FULL_TOKENIZER_FILE,
|
"tokenizer_file": FULL_TOKENIZER_FILE,
|
||||||
|
"chat_template_file": CHAT_TEMPLATE_FILE,
|
||||||
}
|
}
|
||||||
vocab_files = {**cls.vocab_files_names, **additional_files_names}
|
vocab_files = {**cls.vocab_files_names, **additional_files_names}
|
||||||
if "tokenizer_file" in vocab_files:
|
if "tokenizer_file" in vocab_files:
|
||||||
@@ -2097,6 +2099,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
config_tokenizer_class = None
|
config_tokenizer_class = None
|
||||||
init_kwargs = init_configuration
|
init_kwargs = init_configuration
|
||||||
|
|
||||||
|
# If an independent chat template file exists, it takes priority over template entries in the tokenizer config
|
||||||
|
chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
|
||||||
|
if chat_template_file is not None:
|
||||||
|
with open(chat_template_file) as chat_template_handle:
|
||||||
|
init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config
|
||||||
|
|
||||||
if not _is_local:
|
if not _is_local:
|
||||||
if "auto_map" in init_kwargs:
|
if "auto_map" in init_kwargs:
|
||||||
# For backward compatibility with old format.
|
# For backward compatibility with old format.
|
||||||
@@ -2396,6 +2404,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
tokenizer_config_file = os.path.join(
|
tokenizer_config_file = os.path.join(
|
||||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
|
||||||
)
|
)
|
||||||
|
chat_template_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer_config = copy.deepcopy(self.init_kwargs)
|
tokenizer_config = copy.deepcopy(self.init_kwargs)
|
||||||
|
|
||||||
@@ -2418,7 +2429,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
if isinstance(self.chat_template, dict):
|
if isinstance(self.chat_template, dict):
|
||||||
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
|
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
|
||||||
# They will be reconstructed as a single dict during loading.
|
# They will be reconstructed as a single dict during loading.
|
||||||
|
# We're trying to discourage chat template dicts, and they are always
|
||||||
|
# saved in the config, never as single files.
|
||||||
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
|
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
|
||||||
|
elif kwargs.get("save_raw_chat_template", False):
|
||||||
|
with open(chat_template_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(self.chat_template)
|
||||||
|
logger.info(f"chat template saved in {chat_template_file}")
|
||||||
|
if "chat_template" in tokenizer_config:
|
||||||
|
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
|
||||||
else:
|
else:
|
||||||
tokenizer_config["chat_template"] = self.chat_template
|
tokenizer_config["chat_template"] = self.chat_template
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -519,3 +520,27 @@ class ProcessorTesterMixin:
|
|||||||
processor.prepare_and_validate_optional_call_args(
|
processor.prepare_and_validate_optional_call_args(
|
||||||
*(f"optional_{i}" for i in range(num_optional_call_args + 1))
|
*(f"optional_{i}" for i in range(num_optional_call_args + 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_chat_template_save_loading(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
|
||||||
|
processor.chat_template = "test template"
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
processor.save_pretrained(tmpdirname)
|
||||||
|
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
|
||||||
|
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
|
||||||
|
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
|
||||||
|
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
|
||||||
|
# When we don't use single-file chat template saving, processor and tokenizer chat templates
|
||||||
|
# should remain separate
|
||||||
|
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
processor.save_pretrained(tmpdirname, save_raw_chat_template=True)
|
||||||
|
self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
|
||||||
|
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
|
||||||
|
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
|
||||||
|
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
|
||||||
|
# When we save as single files, tokenizers and processors share a chat template, which means
|
||||||
|
# the reloaded tokenizer should get the chat template as well
|
||||||
|
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import traceback
|
|||||||
import unittest
|
import unittest
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import takewhile
|
from itertools import takewhile
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@@ -1107,13 +1108,29 @@ class TokenizerTesterMixin:
|
|||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
tokenizer.save_pretrained(tmp_dir_name)
|
tokenizer.save_pretrained(tmp_dir_name)
|
||||||
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||||
|
|
||||||
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
|
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||||
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||||
# Check that no error raised
|
# Check that no error raised
|
||||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
|
||||||
|
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
|
||||||
|
self.assertTrue(chat_template_file.is_file())
|
||||||
|
self.assertEqual(chat_template_file.read_text(), dummy_template)
|
||||||
|
config_dict = json.loads((Path(tmp_dir_name) / "tokenizer_config.json").read_text())
|
||||||
|
# Assert the chat template is not in the config when it's saved as a separate file
|
||||||
|
self.assertNotIn("chat_template", config_dict)
|
||||||
|
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||||
|
|
||||||
|
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
|
||||||
|
output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
|
||||||
|
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||||
|
# Check that no error raised
|
||||||
|
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_chat_template_batched(self):
|
def test_chat_template_batched(self):
|
||||||
@@ -1526,19 +1543,41 @@ class TokenizerTesterMixin:
|
|||||||
tokenizers = self.get_tokenizers()
|
tokenizers = self.get_tokenizers()
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
for save_raw_chat_template in (True, False):
|
||||||
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
tokenizer.save_pretrained(tmp_dir_name)
|
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
|
||||||
|
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
|
||||||
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
||||||
# Assert that chat templates are correctly serialized as lists of dictionaries
|
# Assert that chat templates are correctly serialized as lists of dictionaries
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
config_dict["chat_template"],
|
config_dict["chat_template"],
|
||||||
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
|
[
|
||||||
|
{"name": "template1", "template": "{{'a'}}"},
|
||||||
|
{"name": "template2", "template": "{{'b'}}"},
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
|
||||||
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||||
# Assert that the serialized list is correctly reconstructed as a single dict
|
# Assert that the serialized list is correctly reconstructed as a single dict
|
||||||
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_chat_template_file_priority(self):
|
||||||
|
dummy_template1 = "a"
|
||||||
|
dummy_template2 = "b"
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
tokenizer.chat_template = dummy_template1
|
||||||
|
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False)
|
||||||
|
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
|
||||||
|
f.write(dummy_template2)
|
||||||
|
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||||
|
# Assert the file template clobbers any template in the config
|
||||||
|
self.assertEqual(new_tokenizer.chat_template, dummy_template2)
|
||||||
|
|
||||||
def test_number_of_added_tokens(self):
|
def test_number_of_added_tokens(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
|
|||||||
Reference in New Issue
Block a user