🚨 🚨 Allow saving and loading multiple "raw" chat template files (#36588)

* Add saving in the new format (but no loading yet!)

* Add saving in the new format (but no loading yet!)

* A new approach to template files!

* make fixup

* make fixup, set correct dir

* Some progress but need to rework for cached_file

* Rework loading handling again

* Small fixes

* Looks like it's working now!

* make fixup

* Working!

* make fixup

* make fixup

* Add TODO so I don't miss it

* Cleaner control flow with one less indent

* Copy the new logic to processing_utils as well

* Proper support for dicts of templates

* make fixup

* define the file/dir names in a single place

* Update the processor chat template reload test as well

* Add processor loading of multiple templates

* Flatten correctly to match tokenizers

* Better support when files are empty sometimes

* Stop creating those empty templates

* Revert changes now we don't have empty templates

* Revert changes now we don't have empty templates

* Don't support separate template files on the legacy path

* Rework/simplify loading code

* Make sure it's always a chat_template key in chat_template.json

* Update processor handling of multiple templates

* Add a full save-loading test to the tokenizer tests as well

* Correct un-flattening

* New test was incorrect

* Correct error/offline handling

* Better exception handling

* More error handling cleanup

* Add skips for test failing on main

* Reorder to fix errors

* make fixup

* clarify legacy processor file docs and location

* Update src/transformers/processing_utils.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/processing_utils.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/processing_utils.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/processing_utils.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Rename to _jinja and _legacy

* Stop saving multiple templates in the legacy format

* Cleanup the processing code

* Cleanup the processing code more

* make fixup

* make fixup

* correct reformatting

* Use correct dir name

* Fix import location

* Use save_jinja_files instead of save_raw_chat_template_files

* Correct the test for saving multiple processor templates

* Fix type hint

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Patch llava_onevision test

* Update src/transformers/processing_utils.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Refactor chat template saving out into a separate function

* Update tests for the new default

* Don't do chat template saving logic when chat template isn't there

* Ensure save_jinja_files is propagated to tokenizer correctly

* Trigger tests

* Update more tests to new default

* Trigger tests

---------

Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
Matt
2025-04-11 16:37:23 +01:00
committed by GitHub
parent 897874748b
commit bf46e44878
9 changed files with 391 additions and 82 deletions

View File

@@ -528,6 +528,7 @@ class AutoModelTest(unittest.TestCase):
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
@unittest.skip("Failing on main")
def test_cached_model_has_minimum_calls_to_head(self):
# Make sure we have cached the model.
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

View File

@@ -291,6 +291,7 @@ class TFAutoModelTest(unittest.TestCase):
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
@unittest.skip("Failing on main")
def test_cached_model_has_minimum_calls_to_head(self):
# Make sure we have cached the model.
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

View File

@@ -767,7 +767,7 @@ class ProcessorTesterMixin:
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
processor.chat_template = "test template"
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
processor.save_pretrained(tmpdirname, save_jinja_files=False)
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
@@ -777,15 +777,34 @@ class ProcessorTesterMixin:
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname, save_raw_chat_template=True)
processor.save_pretrained(tmpdirname)
self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we save as single files, tokenizers and processors share a chat template, which means
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
with tempfile.TemporaryDirectory() as tmpdirname:
processor.chat_template = {"default": "a", "secondary": "b"}
processor.save_pretrained(tmpdirname)
self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
self.assertTrue(Path(tmpdirname, "additional_chat_templates").is_dir())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we save as single files, tokenizers and processors share a chat template, which means
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
with self.assertRaises(ValueError):
# Saving multiple templates in the legacy format is not permitted
with tempfile.TemporaryDirectory() as tmpdirname:
processor.chat_template = {"default": "a", "secondary": "b"}
processor.save_pretrained(tmpdirname, save_jinja_files=False)
@require_torch
def _test_apply_chat_template(
self,

View File

@@ -1151,7 +1151,7 @@ class TokenizerTesterMixin:
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
save_files = tokenizer.save_pretrained(tmp_dir_name)
save_files = tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False)
# Check we aren't saving a chat_template.jinja file
self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
@@ -1163,7 +1163,7 @@ class TokenizerTesterMixin:
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
save_files = tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True)
save_files = tokenizer.save_pretrained(tmp_dir_name)
# Check we are saving a chat_template.jinja file
self.assertTrue(any(file.endswith("chat_template.jinja") for file in save_files))
chat_template_file = Path(tmp_dir_name) / "chat_template.jinja"
@@ -1180,6 +1180,49 @@ class TokenizerTesterMixin:
# Check that no error raised
new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
@require_jinja
def test_chat_template_save_loading(self):
    """Round-trip chat templates through `save_pretrained` / `from_pretrained`.

    For every tokenizer returned by `self.get_tokenizers()` three layouts are exercised:

    1. a single template saved with default settings: only `chat_template.jinja` is written;
    2. a dict of templates saved with default settings: `chat_template.jinja` plus an
       `additional_chat_templates` directory are written;
    3. a dict of templates saved with `save_jinja_files=False`: no standalone template
       file or directory is written.

    In each case the reloaded tokenizer must expose the same `chat_template` that was saved.
    The test is skipped for tokenizers whose `__init__` has no `chat_template` parameter.
    """
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        signature = inspect.signature(tokenizer.__init__)
        if "chat_template" not in signature.parameters:
            self.skipTest("tokenizer doesn't accept chat templates at input")

        # Case 1: a single template is written to chat_template.jinja only.
        tokenizer.chat_template = "test template"
        with tempfile.TemporaryDirectory() as tmpdirname:
            tokenizer.save_pretrained(tmpdirname)
            self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
            self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
            self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir())
            # Reload through the instance under test (as the sibling chat-template tests do) so the
            # class being reloaded is the one that was saved.
            reloaded_tokenizer = tokenizer.from_pretrained(tmpdirname)
            # The processor variant of this test also compares against `processor.tokenizer`;
            # that check is processor-specific and is intentionally not repeated here.
            self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)

        # Case 2: a dict of templates additionally produces the additional_chat_templates directory.
        with tempfile.TemporaryDirectory() as tmpdirname:
            tokenizer.chat_template = {"default": "a", "secondary": "b"}
            tokenizer.save_pretrained(tmpdirname)
            self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file())
            self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
            self.assertTrue(Path(tmpdirname, "additional_chat_templates").is_dir())
            reloaded_tokenizer = tokenizer.from_pretrained(tmpdirname)
            self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)

        # Case 3: with save_jinja_files=False no standalone template files may appear, yet the
        # templates must still survive the round trip.
        with tempfile.TemporaryDirectory() as tmpdirname:
            tokenizer.chat_template = {"default": "a", "secondary": "b"}
            tokenizer.save_pretrained(tmpdirname, save_jinja_files=False)
            self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
            self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
            self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir())
            reloaded_tokenizer = tokenizer.from_pretrained(tmpdirname)
            self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
@require_jinja
def test_chat_template_batched(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
@@ -1669,21 +1712,29 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
for save_raw_chat_template in (True, False):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
for save_jinja_files in (True, False):
tokenizer.chat_template = {"default": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "template1", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
# With save_jinja_files=True a dict of templates is written as .jinja files; otherwise it is serialized into tokenizer_config.json
tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=save_jinja_files)
if save_jinja_files:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
self.assertNotIn("chat_template", config_dict)
self.assertTrue(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir_name, "additional_chat_templates/template2.jinja"))
)
else:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "default", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
@@ -1697,7 +1748,7 @@ class TokenizerTesterMixin:
with self.subTest(f"{tokenizer.__class__.__name__}"):
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.chat_template = dummy_template1
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False)
tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False)
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
f.write(dummy_template2)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)